21 Commits

Author SHA1 Message Date
maxiao1
46da95569f Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'
V0.5.4 dev linhai

See merge request OpenDAS/sglang!9
2025-11-04 12:04:48 +00:00
linhai1
ee77577211 V0.5.4 dev linhai 2025-11-04 12:04:47 +00:00
maxiao1
a9e0e668c4 Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
enable custom_allreduce

See merge request OpenDAS/sglang!6
2025-11-04 09:33:25 +00:00
maxiao1
0c5532b0c1 enable custom_allreduce 2025-11-04 09:33:24 +00:00
maxiao1
785e5e900b Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
Use the custom all reduce from vllm

See merge request OpenDAS/sglang!5
2025-11-03 08:30:29 +00:00
maxiao
d2fdeac22f Use the custom all reduce from vllm 2025-11-03 16:28:21 +08:00
maxiao1
47e4d92348 Merge branch 'v0.5.4_dev_qwennext' into 'v0.5.4_dev'
Adapt qwen3-next

See merge request OpenDAS/sglang!4
2025-11-03 05:18:16 +00:00
maxiao1
0fbecc4364 Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
change sgl_kernel WARP_SIZE to 64

See merge request OpenDAS/sglang!3
2025-11-03 03:33:05 +00:00
maxiao1
75cd34d172 change sgl_kernel WARP_SIZE to 64 2025-11-03 10:17:53 +08:00
maxiao
477fddf28d Adapt qwen3-next 2025-10-30 18:03:07 +08:00
maxiao1
8fc552638f Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'
Adapt w8a8 models

See merge request OpenDAS/sglang!1
2025-10-29 02:09:59 +00:00
maxiao1
eb4ba1c295 update UNBALANCED_MODEL_LOADING_TIMEOUT_S=3600 2025-10-29 10:06:23 +08:00
maxiao1
4b9b337b39 Adapt w8a8 models 2025-10-29 09:06:22 +08:00
lizhigong
f6528b74be Add hipprof support; fix a synchronization issue in async scheduling 2025-10-28 16:25:06 +08:00
maxiao1
a5718531b7 Disable custom_allreduce to preserve correctness 2025-10-28 10:57:25 +08:00
guobj
c333f12547 Add tpot and related metrics to bench_serving.py 2025-10-28 02:11:36 +00:00
maxiao
f9a026ad2b fix fused_add_rms_norm bug 2025-10-27 10:27:57 +08:00
maxiao1
b80ae5e9ff adaptation w4a8 tp 2025-10-25 16:33:07 +08:00
lizhigong
b091a7a5c9 adapt w4a8 marlin deepep dp ep
(cherry picked from commit a0fb70e9c1)
2025-10-25 15:07:57 +08:00
lizhigong
143ec5f36c adaptation w4A8 quantization
(cherry picked from commit 848c5b8290)
2025-10-25 15:07:04 +08:00
lizhigong
67510e0172 adaptation part w4A8 quantization
(cherry picked from commit 68277eac30)
2025-10-25 15:06:27 +08:00
32 changed files with 2306 additions and 64 deletions

View File

@@ -839,10 +839,12 @@ class BenchmarkMetrics:
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
+   p95_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
+   p95_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
@@ -1665,10 +1667,12 @@ def calculate_metrics(
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
+       p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
+       p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
@@ -1974,6 +1978,12 @@ async def benchmark(
        print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
        print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
        print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
+       print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
+       print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
+       print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
+       print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
+       print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+       print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
        print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
        print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
        print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))

View File

@@ -5,11 +5,23 @@ from typing import List, Optional, Tuple
import torch

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

+try:
+    from lmslim import quant_ops
+    from lmslim import quant_tools
+except Exception:
+    print("INFO: Please install lmslim if you want to infer gptq, awq, or w8a8 models.\n")
+try:
+    import lightop
+except Exception:
+    print("INFO: Please install lightop if you want to infer awq or marlin models.\n")
+
logger = logging.getLogger(__name__)

use_vllm_custom_allreduce = get_bool_env_var(
    "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
+use_dcu_custom_allreduce = get_bool_env_var(
+    "USE_DCU_CUSTOM_ALLREDUCE", default="false"
+)

if not is_hpu():
    # ROCm does not use vllm custom allreduce
@@ -24,6 +36,11 @@ if not is_hpu():
        except ImportError as e:
            logger.warning("Failed to import from custom_ar with %r", e)

+   if use_dcu_custom_allreduce:
+       try:
+           import vllm._C
+       except ImportError as e:
+           logger.warning("Failed to import from vllm._C with %r", e)

if not is_hip() and not is_npu():
    if use_vllm_custom_allreduce:
@@ -66,8 +83,79 @@ if not is_hip() and not is_npu():
        ) -> None:
            custom_op.register_graph_buffers(fa, handles, offsets)

+elif is_hip() and use_dcu_custom_allreduce:
+    # custom ar
+    def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
+                       rank: int, fully_connected: bool) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
+                                                     fully_connected)
+
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
+                   reg_buffer_sz_bytes: int) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
+                                          reg_buffer_sz_bytes)
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(fa: int, handles: list[list[int]],
+                               offsets: list[list[int]]) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+    def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
+        return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
+
+    def open_mem_handle(mem_handle: torch.Tensor):
+        return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
+
+    def free_shared_buffer(ptr: int) -> None:
+        torch.ops._C_custom_ar.free_shared_buffer(ptr)
+
+    def read_cache(
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        key_caches: list[torch.Tensor],
+        value_caches: list[torch.Tensor],
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+    ) -> None:
+        torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
+                                          value_caches, slot_mapping,
+                                          kv_cache_dtype)
+
+    def write_cache_multi_layers(
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        key_caches: list[torch.Tensor],
+        value_caches: list[torch.Tensor],
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+    ) -> None:
+        torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
+                                                        value_caches, slot_mapping,
+                                                        kv_cache_dtype)
+
else:
-   # ROCM custom allreduce
+   # sgl_kernel ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
@@ -175,3 +263,25 @@ def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)

+def triton_scaled_mm(a: torch.Tensor,
+                     b: torch.Tensor,
+                     scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype,
+                     bias: Optional[torch.Tensor] = None,
+                     best_config: Optional[list] = None) -> torch.Tensor:
+    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)
+
+def triton_int8_gemm_helper(m: int,
+                            n: int,
+                            k: int,
+                            per_token_act_quant: bool,
+                            per_out_channel_weight_quant: bool,
+                            use_bias: bool,
+                            out_dtype: type[torch.dtype] = torch.float16,
+                            device: str = "cuda:0",
+                            best_config: Optional[list] = None,
+                            repeat: Optional[int] = 2):
+    return quant_tools.triton_int8_gemm_helper(m, n, k, per_token_act_quant,
+                                               per_out_channel_weight_quant, use_bias,
+                                               out_dtype, device, best_config, repeat)
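
For reference, the two toggles above are plain boolean environment variables. A minimal sketch of how a get_bool_env_var-style helper resolves them (the parsing details here are an assumption, not a quote of sglang's helper):

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # assumed semantics: case-insensitive "true"/"1" enables the flag
    return os.getenv(name, default).lower() in ("true", "1")

use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE")
use_dcu_custom_allreduce = get_bool_env_var("USE_DCU_CUSTOM_ALLREDUCE")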

View File

@@ -614,6 +614,8 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
+           "slimquant_w4a8_marlin",
+           "w8a8_int8",
        ]
        optimized_quantization_methods = [
            "fp8",
@@ -633,6 +635,7 @@ class ModelConfig:
            "qoq",
            "w4afp8",
            "petit_nvfp4",
+           "slimquant_w4a8_marlin",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],

View File

@@ -30,6 +30,8 @@ try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
+   elif ops.use_dcu_custom_allreduce:
+       ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel  # noqa: F401
@@ -419,3 +421,274 @@ class CustomAllreduce:
    def __del__(self):
        self.close()

+
+class DCUCustomAllreduce:
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
+
+    # max_size: max supported allreduce size
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 max_size=8192 * 512) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bound to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bound to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if not custom_ar:
+            # disable because of missing custom allreduce library
+            # e.g. in a non-GPU environment
+            logger.info("Custom allreduce is disabled because "
+                        "of missing custom allreduce library")
+            return
+
+        self.group = group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "CustomAllreduce should be attached to a non-NCCL group.")
+
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
+        rank = dist.get_rank(group=self.group)
+        self.rank = rank
+        world_size = dist.get_world_size(group=self.group)
+        # if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
+        if world_size > 16:
+            return
+        if world_size == 1:
+            # No need to initialize custom allreduce for single GPU case.
+            return
+
+        if world_size not in DCUCustomAllreduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom allreduce is disabled due to an unsupported world"
+                " size: %d. Supported world sizes: %s. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.",
+                world_size, str(DCUCustomAllreduce._SUPPORTED_WORLD_SIZES))
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        # assert current_platform.is_cuda_alike()
+        # fully_connected = current_platform.is_fully_connected(
+        #     physical_device_ids)
+        if _is_cuda or _is_hip:
+            fully_connected = is_full_nvlink(physical_device_ids, world_size)
+        # if world_size > 2 and not fully_connected:
+        if not fully_connected:
+            max_size = 32 * 8192 * 2
+            # if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
+            #     logger.warning(
+            #         "Custom allreduce is disabled because it's not supported on"
+            #         " more than two PCIe-only GPUs. To silence this warning, "
+            #         "specify disable_custom_all_reduce=True explicitly.")
+            #     return
+            logger.warning(
+                "Using PCIe custom allreduce. If the performance is poor, "
+                "add --disable-custom-all-reduce to the launch command.")
+
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not _is_hip and not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.")
+            return
+
+        self.disabled = False
+        # Buffers memory are owned by this Python class and passed to C++.
+        # Meta data composes of two parts: meta data for synchronization and a
+        # temporary buffer for storing intermediate allreduce results.
+        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
+                                                   group=group,
+                                                   uncached=True)
+        # This is a pre-registered IPC buffer. In eager mode, input tensors
+        # are first copied into this buffer before allreduce is performed
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        # This is a buffer for storing the tuples of pointers pointing to
+        # IPC buffers from all ranks. Each registered tuple has size of
+        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+        # is enough for 131072 such tuples. The largest model I've seen only
+        # needs less than 10000 of registered tuples.
+        self.rank_data = torch.empty(8 * 1024 * 1024,
+                                     dtype=torch.uint8,
+                                     device=self.device)
+        self.max_size = max_size
+        self.rank = rank
+        self.world_size = world_size
+        self.fully_connected = fully_connected
+        self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
+                                       self.fully_connected)
+        ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+        """
+        try:
+            self._IS_CAPTURING = True
+            yield
+        finally:
+            self._IS_CAPTURING = False
+            if not self.disabled:
+                self.register_graph_buffers()
+
+    def register_graph_buffers(self):
+        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+        logger.info("Registering %d cuda graph addresses", len(offset))
+        # We cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data = [[None, None]
+                    for _ in range(dist.get_world_size(group=self.group))]
+        all_data[self.rank] = [handle, offset]
+        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
+        # Unpack list of tuples to tuple of lists.
+        handles = [d[0] for d in all_data]  # type: ignore
+        offsets = [d[1] for d in all_data]  # type: ignore
+        ops.register_graph_buffers(self._ptr, handles, offsets)
+
+    def should_custom_ar(self, inp: torch.Tensor):
+        if self.disabled:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom allreduce requires input byte size to be multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+        # little performance improvement over NCCL.
+        return inp_size <= self.max_size
+
+    def all_reduce(self,
+                   inp: torch.Tensor,
+                   *,
+                   out: torch.Tensor = None,
+                   registered: bool = False):
+        """Performs an out-of-place all reduce.
+
+        If registered is True, this assumes inp's pointer is already
+        IPC-registered. Otherwise, inp is first copied into a pre-registered
+        buffer.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
+        else:
+            ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
+                           self.max_size)
+        return out
+
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        """The main allreduce API that provides support for cuda graph."""
+        # When custom allreduce is disabled, this will be None.
+        if self.disabled or not self.should_custom_ar(input):
+            return None
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                return self.all_reduce(input, registered=False)
+            else:
+                # If warm up, mimic the allocation pattern since custom
+                # allreduce is out-of-place.
+                return torch.empty_like(input)
+        else:
+            # Note: outside of cuda graph context, custom allreduce incurs a
+            # cost of cudaMemcpy, which should be small (<=1% of overall
+            # latency) compared to the performance gain of using custom kernels
+            return self.all_reduce(input, registered=False)
+
+    def close(self):
+        if not self.disabled and self._ptr:
+            if ops is not None:
+                ops.dispose(self._ptr)
+            self._ptr = 0
+            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
+
+    def __del__(self):
+        self.close()
+
+    @staticmethod
+    def create_shared_buffer(size_in_bytes: int,
+                             group: Optional[ProcessGroup] = None,
+                             uncached: Optional[bool] = False) -> list[int]:
+        pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
+        world_size = dist.get_world_size(group=group)
+        rank = dist.get_rank(group=group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=group)
+        pointers: list[int] = []
+        for i, h in enumerate(handles):
+            if i == rank:
+                pointers.append(pointer)  # type: ignore
+            else:
+                pointers.append(ops.open_mem_handle(h))
+        return pointers
+
+    @staticmethod
+    def free_shared_buffer(pointers: list[int],
+                           group: Optional[ProcessGroup] = None,
+                           rank: Optional[int] = 0) -> None:
+        if rank is None:
+            rank = dist.get_rank(group=group)
+        if ops is not None:
+            ops.free_shared_buffer(pointers[rank])
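
The gating in should_custom_ar() above is easy to check in isolation: the kernel only accepts weakly contiguous inputs whose byte size is a multiple of 16 and at most max_size. A standalone sketch of that arithmetic (contiguity check omitted):

import torch

def should_use_custom_ar(inp: torch.Tensor, max_size: int = 8192 * 512) -> bool:
    inp_size = inp.numel() * inp.element_size()
    return inp_size % 16 == 0 and inp_size <= max_size

print(should_use_custom_ar(torch.empty(1024, dtype=torch.float16)))  # True: 2048 bytes
print(should_use_custom_ar(torch.empty(7, dtype=torch.float16)))     # False: 14 bytes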

View File

@@ -53,6 +53,7 @@ from sglang.srt.utils import (
    is_xpu,
    supports_custom_op,
)
+from sglang.srt import _custom_ops as ops

_is_npu = is_npu()
_is_cpu = is_cpu()
@@ -303,7 +304,7 @@ class GroupCoordinator:
            # Lazy import to avoid documentation build error
            from sglang.srt.distributed.device_communicators.custom_all_reduce import (
-               CustomAllreduce,
+               CustomAllreduce, DCUCustomAllreduce,
            )
            from sglang.srt.distributed.device_communicators.pymscclpp import (
                PyMscclppCommunicator,
@@ -347,11 +348,17 @@ class GroupCoordinator:
            else:
                ca_max_size = 8 * 1024 * 1024
            try:
-               self.ca_comm = CustomAllreduce(
-                   group=self.cpu_group,
-                   device=self.device,
-                   max_size=ca_max_size,
-               )
+               if is_hip() and ops.use_dcu_custom_allreduce:
+                   self.ca_comm = DCUCustomAllreduce(
+                       group=self.cpu_group,
+                       device=self.device,
+                   )
+               else:
+                   self.ca_comm = CustomAllreduce(
+                       group=self.cpu_group,
+                       device=self.device,
+                       max_size=ca_max_size,
+                   )
            except Exception as e:
                logger.warning(
                    f"Setup Custom allreduce failed with {e}. To silence this "
View File

@@ -99,7 +99,6 @@ def create_triton_backend(runner):
    return TritonAttnBackend(runner)

-
@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
    return FlashMLABackend(runner)

+@register_attention_backend("dcu_mla")
+def create_dcu_mla_backend(runner):
+    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
+
+    return DCUMLABackend(runner)

@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
View File

@@ -0,0 +1,484 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

try:
    from flash_mla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_quantization,
        get_mla_metadata,
    )

    _has_flash_mla = True
except Exception:
    try:
        from vllm.attention.ops.flashmla import (
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        _has_flash_mla = False
    except Exception:
        raise ImportError(
            "Cannot import FlashMLA. Please install one of the following to use flashmla:\n"
            "    pip install flash-mla\n"
            "  or\n"
            "    pip install vllm"
        )

PAGE_SIZE = 64  # fixed to 64

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


@dataclass
class VllmMLADecodeMetadata:
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None


class DCUMLABackend(AttentionBackend):
    def __init__(
        self,
        model_runner: "ModelRunner",
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        if model_runner.server_args.page_size != PAGE_SIZE:
            raise ValueError(
                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
                f"but got {model_runner.server_args.page_size}"
            )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.device = model_runner.device
        self.max_context_len = model_runner.model_config.context_len
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.forward_metadata: Optional[VllmMLADecodeMetadata] = None

        self.skip_prefill = skip_prefill
        if not skip_prefill:
            # Used the triton backend at first; may revisit later:
            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
            # self.triton_backend = TritonAttnBackend(
            #     model_runner,
            #     skip_prefill=False,
            #     kv_indptr_buf=kv_indptr_buf,
            # )
            # prefill now uses flash attn instead
            from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

            self.flashattn_backend = FlashAttentionBackend(
                model_runner,
                skip_prefill=False,
            )

    def _build_decode_metadata(
        self,
        forward_batch: ForwardBatch,
        seq_lens: torch.Tensor,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        bs = forward_batch.batch_size
        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

        # paging scheme follows the official vLLM blog post
        block_kv_indices = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_seqlen_pad,
        )
        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32), self.num_q_heads, 1
        )
        return (mla_metadata, num_splits), num_splits, block_kv_indices

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # decode uses flashmla
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        else:
            # prefill/extend used the triton backend -> switched to flash attn
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata(forward_batch)
                self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        if self.num_draft_tokens:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
        else:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_q_heads,
                1,
            )
        self.cuda_graph_mla_metadata = mla_metadata
        self.cuda_graph_num_splits = num_splits
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
    ):
        if forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
                #     bs,
                #     num_tokens,
                #     req_pool_indices,
                #     seq_lens,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                # )
                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                    bs,
                    num_tokens,
                    req_pool_indices,
                    seq_lens,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
                #     bs,
                #     req_pool_indices,
                #     seq_lens,
                #     seq_lens_sum,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                #     seq_lens_cpu,
                # )
                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                    bs,
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                    seq_lens_cpu,
                )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def _call_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
                     block_table: torch.Tensor, cache_seqlens: torch.Tensor,
                     scaling: float):
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
        )
        return o

    def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
                         block_table: torch.Tensor, cache_seqlens: torch.Tensor,
                         scaling: float):
        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
        o, _ = flash_mla_with_kvcache_quantization(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
            is_fp8_kvcache=True,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32), layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32), layer.scaling,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        sinks=None,
    ):
        if (
            forward_batch.forward_mode == ForwardMode.EXTEND
            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
        ):
            # flash_attn does not support fp8; extend cannot run correctly with fp8
            if not self.skip_prefill:
                # return self.triton_backend.forward_extend(
                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
                # )
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache, sinks
                )
            else:
                raise RuntimeError("skip prefill but use forward_extend")

        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
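
The paging math above is the main sizing constraint the backend imposes: the KV cache is split into PAGE_SIZE=64 slot blocks, each sequence needs ceil(seq_len / 64) block indices, and unused entries are padded with -1. A dependency-free sketch of that arithmetic:

PAGE_SIZE = 64

def cdiv(a: int, b: int) -> int:
    # same rounding-up division as triton.cdiv
    return (a + b - 1) // b

for seq_len in (1, 64, 65, 1000):
    print(seq_len, cdiv(seq_len, PAGE_SIZE))  # 1->1, 64->1, 65->2, 1000->16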

View File

@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F

-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,

View File

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache

@dataclass

View File

@@ -0,0 +1,94 @@
from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache,
        v_cache=v_cache,
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
    )
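
The q reshape in flash_attn_with_kvcache above converts sglang's packed layout [total_tokens, num_heads, head_dim] into the batched [batch, seqlen_q, num_heads, head_dim] layout that the flash_attn package expects; this only works when every request in the batch shares the same max_seqlen_q, which the shim implicitly assumes. A small sketch of the view:

import torch

total_tokens, num_heads, head_dim, max_seqlen_q = 8, 4, 128, 2
q = torch.randn(total_tokens, num_heads, head_dim)
q_batched = q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1])
print(q_batched.shape)  # torch.Size([4, 2, 4, 128])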

View File

@@ -45,7 +45,8 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
-   from sgl_kernel.flash_attn import flash_attn_with_kvcache
+   # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+   from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache

@dataclass(frozen=True)

View File

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache

class XPUAttentionBackend(AttentionBackend):

View File

@@ -169,6 +169,14 @@ class RMSNorm(CustomOp):
            try:
                output = torch.empty_like(x)
                residual_out = torch.empty_like(x)
+               fused_add_rms_norm(
+                   x,
+                   residual,
+                   self.weight.data,
+                   self.variance_epsilon,
+               )
+               return x, residual
+           except TypeError:
                fused_add_rms_norm(
                    output,
                    x,
@@ -178,14 +186,7 @@ class RMSNorm(CustomOp):
                    self.variance_epsilon,
                )
                return output, residual_out
-           except TypeError:
-               fused_add_rms_norm(
-                   x,
-                   residual,
-                   self.weight.data,
-                   self.variance_epsilon,
-               )
-               return x, residual

        out = torch.empty_like(x)
        rms_norm(out, x, self.weight.data, self.variance_epsilon)
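
For readers unfamiliar with the kernel, the reordered try/except above now attempts the 4-argument in-place signature of fused_add_rms_norm first and only falls back to the 6-argument out-of-place variant on TypeError. The reference semantics of the fused op, as a plain-PyTorch sketch (an illustration, not the kernel itself):

import torch

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # add the residual, then RMS-normalize the sum; the kernel fuses both steps
    residual = x + residual
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out, residual

x, res, w = torch.randn(2, 8), torch.randn(2, 8), torch.ones(8)
out, new_res = fused_add_rms_norm_ref(x, res, w, 1e-6)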

View File

@@ -3,6 +3,7 @@ from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch

from sglang.srt import single_batch_overlap
@@ -54,7 +55,286 @@ if _use_aiter:
logger = logging.getLogger(__name__)

-class DeepEPMoE(FusedMoE):
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+class EPMoE(FusedMoE):
+    """
+    MoE Expert Parallel Impl
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        num_fused_shared_experts: int = 0,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
+        with_bias: bool = False,
+    ):
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_fused_shared_experts=num_fused_shared_experts,
+            layer_id=layer_id,
+            top_k=top_k,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            prefix=prefix,
+            activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+            with_bias=with_bias,
+        )
+
+        self.intermediate_size = intermediate_size
+
+        if isinstance(quant_config, Fp8Config):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = True
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.activation_scheme = quant_config.activation_scheme
+            self.use_w4a8_marlin = False
+        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = True
+        else:
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
+            self.use_w4a8_marlin = False
+
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, topk_output)
+        else:
+            return super().forward(hidden_states, topk_output)
+
+    def forward_deepgemm(
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ):
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
+        topk_weights, topk_ids, _ = topk_output
+
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
+            ),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return output
+
+
+class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    Mooncake EP shares the same class, as they expose the same interface.
@@ -106,11 +386,28 @@ class DeepEPMoE(FusedMoE):
        self.deepep_mode = get_deepep_mode()

-       if self.deepep_mode.enable_low_latency() and not _is_npu:
-           # NPU supports low_latency deepep without deepgemm
-           assert (
-               deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-           ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+       # TODO: move to the beginning of the file
+       from sglang.srt.distributed.parallel_state import get_tp_group
+       from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+       self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+           group=get_tp_group().device_group,
+           router_topk=self.top_k,
+           permute_fusion=True,
+           num_experts=self.num_experts,
+           num_local_experts=self.num_local_experts,
+           hidden_size=hidden_size,
+           params_dtype=params_dtype,
+           deepep_mode=self.deepep_mode,
+           async_finish=True,  # TODO
+           return_recv_hook=True,
+       )
+       # if self.deepep_mode.enable_low_latency() and not _is_npu:
+       #     # NPU supports low_latency deepep without deepgemm
+       #     assert (
+       #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+       #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
        if _use_aiter:
            # expert_mask is of size (self.num_local_experts + 1),
            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
@@ -124,23 +421,23 @@ class DeepEPMoE(FusedMoE):
            )
            # the last one is invalid rank_id
            self.expert_mask[:-1] = 1
-       elif not _is_npu:
-           self.w13_weight_fp8 = (
-               self.w13_weight,
-               (
-                   self.w13_weight_scale_inv
-                   if self.use_block_quant or self.use_w4afp8
-                   else self.w13_weight_scale
-               ),
-           )
-           self.w2_weight_fp8 = (
-               self.w2_weight,
-               (
-                   self.w2_weight_scale_inv
-                   if self.use_block_quant or self.use_w4afp8
-                   else self.w2_weight_scale
-               ),
-           )
+       # elif not _is_npu:
+       #     self.w13_weight_fp8 = (
+       #         self.w13_weight,
+       #         (
+       #             self.w13_weight_scale_inv
+       #             if self.use_block_quant
+       #             else self.w13_weight_scale
+       #         ),
+       #     )
+       #     self.w2_weight_fp8 = (
+       #         self.w2_weight,
+       #         (
+       #             self.w2_weight_scale_inv
+       #             if self.use_block_quant
+       #             else self.w2_weight_scale
+       #         ),
+       #     )

    def forward(
        self,
@@ -187,10 +484,15 @@ class DeepEPMoE(FusedMoE):
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            return self.forward_npu(dispatch_output)
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-           if self.use_w4afp8:
-               return self.forward_cutlass_w4afp8(dispatch_output)
-           assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-           return self.forward_deepgemm_contiguous(dispatch_output)
+           # assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+           if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+               return self.forward_deepgemm_contiguous(dispatch_output)
+           elif self.use_w4a8_marlin:
+               return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+           else:
+               raise ValueError(
+                   "Dispatch output is not supported"
+               )
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if (
                get_moe_runner_backend().is_flashinfer_cutedsl()
@@ -255,6 +557,34 @@ class DeepEPMoE(FusedMoE):
                expert_mask=self.expert_mask,
            )

+   def forward_deepgemm_w4a8_marlin_contiguous(
+       self,
+       dispatch_output: DeepEPNormalOutput,
+   ):
+       hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+           dispatch_output
+       )
+       assert self.quant_method is not None
+       assert self.moe_runner_config.activation == "silu"
+       # if num_recv_tokens_per_expert is None:
+       return hidden_states_int8.bfloat16()
+       # expert_output = self.quant_method.apply_ep(
+       #     layer=self,
+       #     x=dispatch_output,
+       #     topk_weights=topk_weights,
+       #     topk_ids=topk_idx,
+       #     global_num_experts=self.global_num_experts,
+       #     expert_map=self.expert_map,
+       #     activation=self.activation,
+       #     apply_router_weight_on_input=self.apply_router_weight_on_input,
+       #     use_nn_moe=self.use_nn_moe,
+       #     num_local_tokens=dispatch_recv_num_token,
+       #     config_select_bs=hidden_states.shape[0],
+       #     scales=dispatch_scales if self.use_int8_dispatch else None
+       #     # routed_scaling_factor=self.routed_scaling_factor,
+       # )
+       # return expert_output

    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
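
One step in forward_deepgemm above deserves a worked example: when weights are per-tensor quantized, the single scale is expanded to the (n_blocks, k_blocks) grid that block-quantized grouped GEMM expects via two repeat_interleave calls. A standalone sketch with hypothetical sizes:

import torch

scale_block_size = 128
intermediate_size, hidden = 256, 384
w13_scale = torch.tensor([0.5, 0.25])  # hypothetical per-tensor scale, one per expert
n_blocks = 2 * ((intermediate_size + scale_block_size - 1) // scale_block_size)
k_blocks = (hidden + scale_block_size - 1) // scale_block_size
blockwise = (
    w13_scale.unsqueeze(1).repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2).repeat_interleave(k_blocks, dim=2)
)
print(blockwise.shape)  # torch.Size([2, 4, 3]): (experts, n_blocks, k_blocks)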

View File

@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.layers.quantization.int8_kernel import (
    per_token_group_quant_int8,
-   per_token_quant_int8,
+   # per_token_quant_int8,
    sglang_per_token_group_quant_int8,
)
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,

View File

@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        overlap_args: Optional["CombineOverlapArgs"],
    ):
-       if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
-           output = hidden_states
-       else:
-           raise NotImplementedError()  # triton runner was supported but it's temporarily disabled
+       # if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
+       output = hidden_states
+       # else:
+       #     if hidden_states.shape[0] > 0:
+       #         num_tokens = self.src2dst.shape[0] // self.router_topk
+       #         output = torch.empty(
+       #             (num_tokens, hidden_states.shape[1]),
+       #             device=hidden_states.device,
+       #             dtype=hidden_states.dtype,
+       #         )
+       #         deepep_post_reorder_triton_kernel[(num_tokens,)](
+       #             hidden_states,
+       #             output,
+       #             self.src2dst,
+       #             topk_idx,
+       #             topk_weights,
+       #             self.router_topk,
+       #             hidden_states.shape[1],
+       #             BLOCK_SIZE=512,
+       #         )
+       #     else:
+       #         output = torch.zeros(
+       #             (0, hidden_states.shape[1]),
+       #             device=hidden_states.device,
+       #             dtype=hidden_states.dtype,
+       #         )
        previous_event = Buffer.capture() if self.async_finish else None
        return output, previous_event

View File

@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

_is_mxfp_supported = mxfp_supported()
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
    "fbgemm_fp8": FBGEMMFp8Config,
+   "slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
}
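
The registry entry added above is what lets a quantization method name resolve to its config class at startup. A toy sketch of that lookup pattern (stand-in classes, not the real ones):

from typing import Dict, Type

class QuantizationConfig: ...
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...

BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
}
config_cls = BASE_QUANTIZATION_METHODS["slimquant_w4a8_marlin"]
print(config_cls.__name__)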

View File

@@ -0,0 +1,415 @@
from typing import Any, Callable, Dict, List, Optional
import os

import torch
from torch.nn.parameter import Parameter

from sglang.srt import _custom_ops as ops
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    _ColumnvLLMParameter,
    RowvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from lmslim.layers.gemm.int8_utils import (
    per_token_group_quant_int8,
    per_token_quant_int8,
)
from vllm.utils import W8a8GetCacheJSON
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass


W8A8_TRITONJSON = W8a8GetCacheJSON()
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Unoptimized fp32 reference GEMM: dequantize via the outer product of the
    # per-token and per-channel scales, then cast to the requested dtype.
    scales = scale_a * scale_b.T
    gemmout = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scales * gemmout).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output.to(out_dtype)
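# Editor's sketch (not part of the source): baseline_scaled_mm can be used to
# sanity-check the optimized quantized GEMM paths below. The shapes here are
# illustrative only.
#
#     a = torch.randint(-128, 127, (4, 64), dtype=torch.int8)  # quantized activations
#     b = torch.randint(-128, 127, (64, 8), dtype=torch.int8)  # quantized weights
#     scale_a = torch.rand(4, 1)   # one scale per token (row)
#     scale_b = torch.rand(8, 1)   # one scale per output channel
#     ref = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)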
class SlimQuantW4A8Int8Config(QuantizationConfig):
    """Config class for SlimQuant W4A8 Int8 quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_w4a8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
    def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
        self.quantization_config = quantization_config
        self.tritonsingleton = W8a8GetCacheJSON()
        # 1 = triton, 2 = cutlass, anything else = rocBLAS (see apply()).
        self.w8a8_strategy = int(os.getenv("W8A8_SUPPORT_METHODS", "1"))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            # Warm up the triton GEMM once per unique weight shape; the shape
            # pair is stored as a set, so (n, k) and (k, n) are treated alike.
            if {n, k} not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append({n, k})
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    for key, value in configs_dict.items():
                        m = int(key.split("_")[0])
                        ops.triton_int8_gemm_helper(
                            m=m,
                            n=n,
                            k=k,
                            per_token_act_quant=True,
                            per_out_channel_weight_quant=True,
                            use_bias=False,
                            device=layer.weight.device,
                            best_config=value,
                        )
        else:
            weight_data = layer.weight.data
            _weight = weight_data.T.contiguous().reshape(n, -1)
            layer.weight.data = _weight
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
self.logical_widths = output_partition_sizes
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        input_quant_args: Optional[list[torch.Tensor]] = None,
        silu_quant_args: Optional[list[torch.Tensor]] = None,
    ):
        # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
        #     assert len(input_quant_args) == 2
        #     x_q, x_scale = input_quant_args
        # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
        #     x_q, x_scale = silu_quant_args
        # else:
        x_q, x_scale = per_token_quant_int8(x)
        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]
            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
                best_config = None
            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
                # Bucket the runtime m to the nearest pre-tuned config key.
                if m <= 16:
                    m_ = m
                elif m <= 64:
                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
                elif m <= 160:
                    m_ = (m + 7) & -8  # round up to the nearest multiple of 8
                elif m < 200:  # 256
                    m_ = 160
                elif m < 480:  # 512
                    m_ = 256
                elif m < 960:  # 1024
                    m_ = 512
                elif m < 2048:
                    m_ = 1024
                elif m < 4096:
                    m_ = 2048
                elif m < 6000:
                    m_ = 4096
                else:
                    m_ = 8192
                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
            else:
                best_config = None
            # if best_config is None:
            #     print("m:{},n:{},k:{}".format(m, n, k))
            #     print("config not found!")
            return ops.triton_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
                best_config=best_config,
            )
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
            )
        else:
            return ops.rocblas_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
            )
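# Editor's sketch (not part of the source): the m-bucketing above, written as a
# standalone helper for clarity. round_up_m is a hypothetical name; it mirrors
# the branch logic in apply() so that a tuned config keyed "{m_}_{n}_{k}" can
# be looked up for any runtime batch size m.
#
#     def round_up_m(m: int) -> int:
#         if m <= 16:
#             return m
#         if m <= 64:
#             return (m + 3) & -4
#         if m <= 160:
#             return (m + 7) & -8
#         for bound, bucket in ((200, 160), (480, 256), (960, 512),
#                               (2048, 1024), (4096, 2048), (6000, 4096)):
#             if m < bound:
#                 return bucket
#         return 8192
#
#     assert round_up_m(70) == 72 and round_up_m(300) == 256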
class SlimQuantW4A8Int8MoEMethod:
    """MoE method for W4A8 INT8.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        # Dynamically rebase the class onto FusedMoEMethodBase. Note the
        # "_initialized" guard attribute is never set on the class, so this
        # path runs on every construction.
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
        self.tritonsingleton = W8a8GetCacheJSON()
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        tp_size = get_tensor_model_parallel_world_size()
        # WEIGHTS: int4 values packed two-per-int8 along the K dimension,
        # hence the trailing dimension of hidden_size // 2.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        E = layer.w13_weight.shape[0]
        N1 = layer.w13_weight.shape[1]
        N2 = layer.w2_weight.shape[1]
        K = N1 // 2
        if [E, N1, N2, K] not in self.tritonsingleton.moe_weight_shapes:
            self.tritonsingleton.moe_weight_shapes.append([E, N1, N2, K])
            TOPK = self.tritonsingleton.topk
            json_file = self.tritonsingleton.get_moeint8json_name(
                E, N1, N2, K, TOPK, use_int4_w4a8=True
            )
            configs_dict = self.tritonsingleton.get_moeint8_triton_cache(
                json_file, E, N1, N2, K, TOPK
            )
            # warmup: cache any pre-tuned triton configs for this MoE shape
            if configs_dict:
                self.tritonsingleton.triton_moejson_dict.update(configs_dict)
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
)
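# Editor's note (sketch, not in the source): the config carries no parameters,
# so resolving a checkpoint whose quantization_config says
# {"quant_method": "slimquant_w4a8"} reduces to:
#
#     cfg = SlimQuantW4A8Int8Config.from_config({})
#     assert cfg.get_name() == "slimquant_w4a8"
#     # get_quant_method() then picks the linear or fused-MoE method per layer.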

View File

@@ -0,0 +1,319 @@
from typing import Any, Callable, Dict, List, Optional
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
import torch
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
try:
    from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception:
    print("INFO: please install lmslim to run quantized w4a8 MoE models.\n")
class MarlinMoeWorkspace:
    """
    Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
    global_reduce_buffer takes about 1.5 MB per CU (roughly 120 MB on a BW200)
    on each device.
    """

    _instances = {}

    def __new__(cls, device):
        if device not in cls._instances:
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instances[device] = instance
        return cls._instances[device]

    def __init__(self, device):
        if self._initialized:
            return
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        self.workspace = torch.zeros(
            500, dtype=torch.int, device=device, requires_grad=False
        )
        # sized as CUs * 6 * 128 * 512 int32 elements
        self.global_reduce_buffer = torch.zeros(
            sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
        )
        self._initialized = True

    def get_buffers(self):
        return self.workspace, self.global_reduce_buffer
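# Editor's note (usage sketch): because MarlinMoeWorkspace is keyed by device,
# repeated lookups return the same buffers without reallocating:
#
#     ws = MarlinMoeWorkspace(torch.device("cuda:0"))
#     workspace, reduce_buf = ws.get_buffers()
#     assert ws is MarlinMoeWorkspace(torch.device("cuda:0"))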
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    scales = scale_a * scale_b.T
    gemmout = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scales * gemmout).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output.to(out_dtype)
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
    """Config class for W4A8 Int8 quantization (marlin kernels).

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_w4a8_marlin"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
        return cls()

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        # Serve a checkpoint quantized as "slimquant_w4a8" with the marlin
        # kernels when the user explicitly requests them.
        if (
            hf_quant_cfg.get("quant_method") == "slimquant_w4a8"
            and user_quant == "slimquant_w4a8_marlin"
        ):
            return cls.get_name()
        return None
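    # Editor's note (sketch): for a checkpoint whose hf quant config says
    # {"quant_method": "slimquant_w4a8"}, requesting the marlin variant (for
    # example via sglang's --quantization flag, assumed here) resolves to:
    #
    #     name = SlimQuantW4A8Int8MarlinConfig.override_quantization_method(
    #         {"quant_method": "slimquant_w4a8"}, "slimquant_w4a8_marlin"
    #     )
    #     assert name == "slimquant_w4a8_marlin"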
    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MarlinMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class SlimQuantW4A8Int8MarlinMoEMethod:
    """MoE method for W4A8 INT8 Marlin.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        # Dynamically rebase the class onto FusedMoEMethodBase; the
        # "_initialized" guard is never set on the class, so this path runs on
        # every construction.
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = intermediate_size_per_partition
        # WEIGHTS: int4 packed two-per-int8 along K, hence hidden_size // 2.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
        # Repack the packed-int4 weights into the marlin tile layout once at load time.
        layer.w13_weight = Parameter(
            w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False
        )
        layer.w2_weight = Parameter(
            w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False
        )
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output,
    ):
        from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )
        workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
output = fused_experts_impl_w4a8_marlin(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=layer.moe_runner_config.activation,
# expert_map=layer.expert_map_gpu,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
)
return StandardCombineInput(hidden_states=output)
# def _apply(
# self,
# layer: torch.nn.Module,
# x: torch.Tensor,
# router_logits: torch.Tensor,
# top_k: int,
# #renormalize: bool,
# #use_grouped_topk: bool = False,
# topk_group: Optional[int] = None,
# num_expert_group: Optional[int] = None,
# global_num_experts: int = -1,
# expert_map: Optional[torch.Tensor] = None,
# custom_routing_function: Optional[Callable] = None,
# scoring_func: str = "softmax",
# e_score_correction_bias: Optional[torch.Tensor] = None,
# apply_router_weight_on_input: bool = False,
# activation: str = "silu",
# enable_eplb: bool = False,
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# **_
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
# if enable_eplb:
# raise NotImplementedError(
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# # Expert selection
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=x,
# router_logits=router_logits,
# #use_grouped_topk=use_grouped_topk,
# top_k=top_k,
# #renormalize=renormalize,
# topk_group=topk_group,
# num_expert_group=num_expert_group,
# custom_routing_function=custom_routing_function,
# scoring_func=scoring_func,
# e_score_correction_bias=e_score_correction_bias,
# routed_scaling_factor=routed_scaling_factor,
# use_fused_gate=use_fused_gate
# )
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
# return fused_experts_impl_w4a8_marlin(
# x,
# layer.w13_weight,
# layer.w2_weight,
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# workspace=workspace,
# global_reduce_buffer=global_reduce_buffer,
# inplace=True,
# use_int4_w4a8=True,
# per_channel_quant=True,
# activation=activation,
# expert_map=expert_map,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# w1_scale=(layer.w13_weight_scale),
# w2_scale=(layer.w2_weight_scale),
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )

View File

@@ -0,0 +1,92 @@
import torch
import numpy as np
try:
    from lightop import awq_marlin_repack_w4a8
    # The lightop fast path is currently disabled even when the import
    # succeeds; the pure-torch repack below is used instead.
    use_lightop = False
except Exception:
    use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """
    Convert a torch.int8 tensor of shape [N, K//2] into a torch.int32 tensor of
    shape [N, K]. Each int8 holds two int4 values; each is extracted into the
    low 4 bits of an int32, with the remaining bits set to 0.

    Args:
        tensor_int8 (torch.Tensor): input tensor of shape [N, K//2], dtype torch.int8.

    Returns:
        torch.Tensor: output tensor of shape [N, K], dtype torch.int32.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")
    N, K_half = tensor_int8.shape
    tensor_uint8 = tensor_int8.to(torch.uint8)
    low4 = tensor_uint8 & 0x0F
    high4 = (tensor_uint8 >> 4) & 0x0F
    unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
    # the upper nibble of each byte maps to the even output column,
    # the lower nibble to the odd one
    unpacked[:, 0::2] = high4.to(torch.int32)
    unpacked[:, 1::2] = low4.to(torch.int32)
    return unpacked
def get_weight_perms(interleave: bool = True):
    # Build the permutation that maps a 64x64 tile into the marlin kernel layout.
    perm = []
    for i in range(64):
        for col in range(4):
            cur_col = (i % 16) * 4 + col
            for row in range(8):
                cur_row = (i // 16) * 8 + row
                cur_idx = cur_row * 64 + cur_col
                perm.append(cur_idx)
    perm = np.array(perm)
    if interleave:
        interleave_map = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        perm = perm.reshape((-1, 8))[:, interleave_map].ravel()
    perm = torch.from_numpy(perm)
    return perm
def marlin_weights(q_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8):
    size_k, size_n = q_w.shape
    # Rearrange into (k_tile x n_tile) tiles, apply the marlin permutation,
    # then pack 8 int4 values into each int32.
    q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
    q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
    orig_device = q_w.device
    q_w = q_w.contiguous().to(torch.int32)
    M, N = q_w.shape
    assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
    q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
    for i in range(pack_factor):
        q_packed += q_w[:, i::pack_factor] << (4 * i)
    return q_packed
def w4a8_2_marlin_weight(w4a8_w):
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
full_w4a8_w = full_w4a8_w.T
weight_perm = get_weight_perms()
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
return marlin_q_w
def w4a8_weight_repack_impl(input):
if use_lightop:
size_batch = input.shape[0]
size_n = input.shape[1]
size_k = input.shape[2] * 2
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
else:
w_marlin_list = []
for e in range(input.shape[0]):
w_marlin_in = w4a8_2_marlin_weight(input[e])
w_marlin_list.append(w_marlin_in)
output = torch.stack(w_marlin_list, dim=0)
return output
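# Editor's sketch (not part of the source): a one-byte sanity check of the
# unpack layout. 0x2F unpacks to [0x2, 0xF]: upper nibble to the even column,
# lower nibble to the odd one.
if __name__ == "__main__":
    b = torch.tensor([[0x2F]], dtype=torch.int8)
    assert unpack_int8_to_int4(b).tolist() == [[0x2, 0xF]]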

View File

@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
         CombineInput,
         StandardDispatchOutput,
     )
+from lmslim import quant_ops
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
         output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-        output = int8_scaled_mm(
+        output = quant_ops.triton_scaled_mm(
             x_q_2d,
             layer.weight,
             x_scale_2d,

View File

@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
-        self.seq_lens_sum = self.seq_lens.sum().item()
+        self.seq_lens_sum = self.seq_lens.sum()
         self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:

View File

@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
"triton", "triton",
"flashmla", "flashmla",
"cutlass_mla", "cutlass_mla",
"dcu_mla",
"trtllm_mla", "trtllm_mla",
"ascend", "ascend",
"nsa", "nsa",
@@ -203,7 +204,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Detect stragger ranks in model loading # Detect stragger ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077) # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3 MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

View File

@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "flashmla")

+def handle_attention_dcu_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "dcu_mla")

 def handle_attention_cutlass_mla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
 AttentionBackendRegistry.register("fa3", handle_attention_fa3)
 AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
 AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
 AttentionBackendRegistry.register("fa4", handle_attention_fa4)
 AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)

View File

@@ -396,7 +396,7 @@ class Qwen3GatedDeltaNet(nn.Module):
     def _forward_input_proj(self, hidden_states: torch.Tensor):
         DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
         seq_len, _ = hidden_states.shape
-        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
+        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
             projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)

View File

@@ -0,0 +1,58 @@
from ctypes import c_char_p, c_int, cdll
import os
import threading
import time


class Prof:
    def __init__(self):
        # Only active when SGLANG_HIP_PROF is set; otherwise every call is a no-op.
        self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
        self.lib = None
        self.roc_tracer_flag = False
        if self.use_roctx:
            self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctxRangePushA.argtypes = [c_char_p]
            self.lib.roctxRangePushA.restype = c_int
            self.lib.roctxRangePop.restype = c_int
        self.tm = time.perf_counter()
        # per-thread nesting depth of open ranges
        self.push_depth = {}

    def StartTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        current_thread = threading.current_thread()
        thread_id = current_thread.ident
        if thread_id not in self.push_depth.keys():
            self.push_depth[thread_id] = 0
        if num < 0 and self.push_depth[thread_id] == 0:
            # nothing to pop on this thread
            return False
        self.push_depth[thread_id] += num
        return True

    def ProfRangePush(self, message):
        if self.use_roctx and self.roc_tracer_flag:
            self.lib.roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self.lib.roctxRangePop()

    def ProfRangeAutoPush(self, message):
        # close the current range (if any) and immediately open a new one
        self.ProfRangePop()
        self.ProfRangePush(message)


profile = Prof()
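# Editor's note (usage sketch): with SGLANG_HIP_PROF set in the environment,
# regions can be bracketed for hipprof/roctracer like so:
#
#     profile.StartTracer()
#     profile.ProfRangePush("decode_step")
#     ...  # work to be profiled
#     profile.ProfRangePop()
#     profile.StopTracer()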

View File

@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
"w4afp8", "w4afp8",
"mxfp4", "mxfp4",
"compressed-tensors", # for Ktransformers "compressed-tensors", # for Ktransformers
"slimquant_w4a8_marlin",
] ]
ATTENTION_BACKEND_CHOICES = [ ATTENTION_BACKEND_CHOICES = [
@@ -101,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
"torch_native", "torch_native",
"flex_attention", "flex_attention",
"nsa", "nsa",
# ransplant from vllm
"dcu_mla",
# NVIDIA specific # NVIDIA specific
"cutlass_mla", "cutlass_mla",
"fa3", "fa3",
@@ -1076,9 +1079,11 @@ class ServerArgs:
if ( if (
self.attention_backend == "flashmla" self.attention_backend == "flashmla"
or self.decode_attention_backend == "flashmla" or self.decode_attention_backend == "flashmla"
or self.attention_backend == "dcu_mla"
or self.decode_attention_backend == "dcu_mla"
): ):
logger.warning( logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64." "FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
) )
self.page_size = 64 self.page_size = 64

View File

@@ -25,7 +25,7 @@
 #define INTRIN_M 16
 #define INTRIN_N 16
 #define INTRIN_K 32
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #define SMEM_PAD_A 0
 #define SMEM_PAD_B 0
 #define PACK_SIZE 16

View File

@@ -25,7 +25,7 @@
 #define INTRIN_M 16
 #define INTRIN_N 16
 #define INTRIN_K 32
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #define SMEM_PAD_A 0
 #define SMEM_PAD_B 0
 #define PACK_SIZE 16

View File

@@ -5,7 +5,7 @@
 #include <cstdint>

 #ifndef USE_ROCM
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #include "pytorch_extension_utils.h"
 #else
 #include "pytorch_extension_utils_rocm.h"

View File

@@ -3,7 +3,7 @@
 // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
 #define QK_K 256
 #define K_QUANTS_PER_ITERATION 2
-#define WARP_SIZE_GGUF 32
+#define WARP_SIZE_GGUF 64
 #define K_SCALE_SIZE 12
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_QUANTIZE_BLOCK_SIZE 256

View File

@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))

 #ifndef USE_ROCM
-#define WARP_SIZE 32
+#define WARP_SIZE 64
 #else
 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
 #define WARP_SIZE 64