Compare commits
2 Commits
v0.5.4_dev
...
v0.5.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4dff7f5ef | ||
|
|
c0352f4aab |
@@ -839,12 +839,10 @@ class BenchmarkMetrics:
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p95_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p95_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
@@ -1667,12 +1665,10 @@ def calculate_metrics(
|
||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
@@ -1978,12 +1974,6 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
|
||||
@@ -19,9 +19,6 @@ logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = get_bool_env_var(
|
||||
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
|
||||
)
|
||||
use_dcu_custom_allreduce= get_bool_env_var(
|
||||
"USE_DCU_CUSTOM_ALLREDUCE", default="false"
|
||||
)
|
||||
|
||||
if not is_hpu():
|
||||
# ROCm does not use vllm custom allreduce
|
||||
@@ -36,11 +33,6 @@ if not is_hpu():
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
if use_dcu_custom_allreduce:
|
||||
try:
|
||||
import vllm._C
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
if not is_hip() and not is_npu():
|
||||
if use_vllm_custom_allreduce:
|
||||
@@ -83,79 +75,8 @@ if not is_hip() and not is_npu():
|
||||
) -> None:
|
||||
custom_op.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
elif is_hip and use_dcu_custom_allreduce:
|
||||
# custom ar
|
||||
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
|
||||
rank: int, fully_connected: bool) -> int:
|
||||
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
|
||||
fully_connected)
|
||||
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
|
||||
reg_buffer_sz_bytes: int) -> None:
|
||||
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
|
||||
reg_buffer_sz_bytes)
|
||||
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops._C_custom_ar.dispose(fa)
|
||||
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops._C_custom_ar.meta_size()
|
||||
|
||||
|
||||
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
|
||||
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
|
||||
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
|
||||
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
|
||||
def register_graph_buffers(fa: int, handles: list[list[int]],
|
||||
offsets: list[list[int]]) -> None:
|
||||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
|
||||
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
|
||||
|
||||
|
||||
def open_mem_handle(mem_handle: torch.Tensor):
|
||||
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
|
||||
|
||||
|
||||
def free_shared_buffer(ptr: int) -> None:
|
||||
torch.ops._C_custom_ar.free_shared_buffer(ptr)
|
||||
|
||||
|
||||
def read_cache(
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
key_caches: list[torch.Tensor],
|
||||
value_caches: list[torch.Tensor],
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
|
||||
value_caches, slot_mapping,
|
||||
kv_cache_dtype)
|
||||
|
||||
def write_cache_multi_layers(
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
key_caches: list[torch.Tensor],
|
||||
value_caches: list[torch.Tensor],
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
|
||||
value_caches, slot_mapping,
|
||||
kv_cache_dtype)
|
||||
|
||||
else:
|
||||
# sgl_kernel ROCM custom allreduce
|
||||
# ROCM custom allreduce
|
||||
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
|
||||
@@ -614,8 +614,6 @@ class ModelConfig:
|
||||
"petit_nvfp4",
|
||||
"quark",
|
||||
"mxfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
"w8a8_int8",
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
@@ -635,7 +633,6 @@ class ModelConfig:
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
"petit_nvfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
|
||||
@@ -30,8 +30,6 @@ try:
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
# Use vLLM custom allreduce
|
||||
ops.meta_size()
|
||||
elif ops.use_dcu_custom_allreduce:
|
||||
ops.meta_size()
|
||||
else:
|
||||
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
|
||||
import sgl_kernel # noqa: F401
|
||||
@@ -421,274 +419,3 @@ class CustomAllreduce:
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
class DCUCustomAllreduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 512) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the CustomAllreduce to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device, and all communicators in this group
|
||||
are in the same node.
|
||||
"""
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-GPU environment
|
||||
logger.info("Custom allreduce is disabled because "
|
||||
"of missing custom allreduce library")
|
||||
return
|
||||
|
||||
self.group = group
|
||||
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"CustomAllreduce should be attached to a non-NCCL group.")
|
||||
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom allreduce for multi-node case.
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because this process group"
|
||||
" spans across nodes.")
|
||||
return
|
||||
|
||||
rank = dist.get_rank(group=self.group)
|
||||
self.rank = rank
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
|
||||
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
|
||||
if world_size > 16:
|
||||
return
|
||||
|
||||
if world_size == 1:
|
||||
# No need to initialize custom allreduce for single GPU case.
|
||||
return
|
||||
|
||||
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled due to an unsupported world"
|
||||
" size: %d. Supported world sizes: %s. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.",
|
||||
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id],
|
||||
dtype=torch.int,
|
||||
device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(gather_list, tensor, group=self.group)
|
||||
physical_device_ids = [t.item() for t in gather_list]
|
||||
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
# assert current_platform.is_cuda_alike()
|
||||
# fully_connected = current_platform.is_fully_connected(
|
||||
# physical_device_ids)
|
||||
if _is_cuda or _is_hip:
|
||||
fully_connected = is_full_nvlink(physical_device_ids, world_size)
|
||||
|
||||
# if world_size > 2 and not fully_connected:
|
||||
if not fully_connected:
|
||||
max_size = 32 * 8192 * 2
|
||||
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
|
||||
# logger.warning(
|
||||
# "Custom allreduce is disabled because it's not supported on"
|
||||
# " more than two PCIe-only GPUs. To silence this warning, "
|
||||
# "specify disable_custom_all_reduce=True explicitly.")
|
||||
# return
|
||||
logger.warning(
|
||||
"We are using PCIe's custom allreduce."
|
||||
"If the performance is poor, we can add "
|
||||
"--disable-custom-all-reduce in the instruction.")
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not _is_hip and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
"warning, specify disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
|
||||
self.disabled = False
|
||||
# Buffers memory are owned by this Python class and passed to C++.
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
|
||||
group=group,
|
||||
uncached=True)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(8 * 1024 * 1024,
|
||||
dtype=torch.uint8,
|
||||
device=self.device)
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.fully_connected = fully_connected
|
||||
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
|
||||
self.fully_connected)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
The main responsibility of this context manager is the
|
||||
`register_graph_buffers` call at the end of the context.
|
||||
It records all the buffer addresses used in the CUDA graph.
|
||||
"""
|
||||
try:
|
||||
self._IS_CAPTURING = True
|
||||
yield
|
||||
finally:
|
||||
self._IS_CAPTURING = False
|
||||
if not self.disabled:
|
||||
self.register_graph_buffers()
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data = [[None, None]
|
||||
for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
for i, rank in enumerate(ranks):
|
||||
dist.broadcast_object_list(all_data[i],
|
||||
src=rank,
|
||||
group=self.group,
|
||||
device="cpu")
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
if not is_weak_contiguous(inp):
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
return inp_size <= self.max_size
|
||||
|
||||
def all_reduce(self,
|
||||
inp: torch.Tensor,
|
||||
*,
|
||||
out: torch.Tensor = None,
|
||||
registered: bool = False):
|
||||
"""Performs an out-of-place all reduce.
|
||||
|
||||
If registered is True, this assumes inp's pointer is already
|
||||
IPC-registered. Otherwise, inp is first copied into a pre-registered
|
||||
buffer.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
if registered:
|
||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||
else:
|
||||
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
|
||||
self.max_size)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
return None
|
||||
if self._IS_CAPTURING:
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
return self.all_reduce(input, registered=False)
|
||||
else:
|
||||
# If warm up, mimic the allocation pattern since custom
|
||||
# allreduce is out-of-place.
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# Note: outside of cuda graph context, custom allreduce incurs a
|
||||
# cost of cudaMemcpy, which should be small (<=1% of overall
|
||||
# latency) compared to the performance gain of using custom kernels
|
||||
return self.all_reduce(input, registered=False)
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
if ops is not None:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
|
||||
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: list[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
else:
|
||||
pointers.append(ops.open_mem_handle(h))
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
if ops is not None:
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
|
||||
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
|
||||
is_xpu,
|
||||
supports_custom_op,
|
||||
)
|
||||
from sglang.srt import _custom_ops as ops
|
||||
|
||||
_is_npu = is_npu()
|
||||
_is_cpu = is_cpu()
|
||||
@@ -304,7 +303,7 @@ class GroupCoordinator:
|
||||
|
||||
# Lazy import to avoid documentation build error
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce, DCUCustomAllreduce
|
||||
CustomAllreduce,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import (
|
||||
PyMscclppCommunicator,
|
||||
@@ -348,17 +347,11 @@ class GroupCoordinator:
|
||||
else:
|
||||
ca_max_size = 8 * 1024 * 1024
|
||||
try:
|
||||
if is_hip() and ops.use_dcu_custom_allreduce:
|
||||
self.ca_comm = DCUCustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
max_size=ca_max_size,
|
||||
)
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
max_size=ca_max_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Setup Custom allreduce failed with {e}. To silence this "
|
||||
|
||||
@@ -169,14 +169,6 @@ class RMSNorm(CustomOp):
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -186,7 +178,14 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Normal file → Executable file
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Normal file → Executable file
@@ -14,10 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
per_token_group_quant_int8,
|
||||
# per_token_quant_int8,
|
||||
per_token_quant_int8,
|
||||
sglang_per_token_group_quant_int8,
|
||||
)
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
|
||||
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Normal file → Executable file
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Normal file → Executable file
@@ -57,7 +57,6 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||
|
||||
_is_mxfp_supported = mxfp_supported()
|
||||
@@ -84,7 +83,6 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"w4afp8": W4AFp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||
import torch
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
@@ -218,9 +218,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
dispatch_output,
|
||||
) :
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
@@ -242,7 +241,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=layer.moe_runner_config.activation,
|
||||
# expert_map=layer.expert_map_gpu,
|
||||
expert_map=layer.expert_map_gpu,
|
||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||
global_num_experts=layer.moe_runner_config.num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from lightop import awq_marlin_repack_w4a8
|
||||
use_lightop = False
|
||||
except Exception:
|
||||
use_lightop = False
|
||||
|
||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||
|
||||
Args:
|
||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||
"""
|
||||
if tensor_int8.dtype != torch.int8:
|
||||
raise ValueError("Input tensor must be of type torch.int8")
|
||||
|
||||
N, K_half = tensor_int8.shape
|
||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||
high4 = tensor_uint8 & 0x0F
|
||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||
|
||||
return unpacked
|
||||
|
||||
def get_weight_perms(interleave: bool=True):
|
||||
perm = []
|
||||
for i in range(64):
|
||||
|
||||
for col in range(4):
|
||||
cur_col = (i % 16) * 4 + col
|
||||
for row in range(8):
|
||||
cur_row = (i // 16) * 8 + row
|
||||
cur_idx = cur_row * 64 + cur_col
|
||||
perm.append(cur_idx)
|
||||
|
||||
perm = np.array(perm)
|
||||
if interleave:
|
||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||
|
||||
perm = torch.from_numpy(perm)
|
||||
|
||||
return perm
|
||||
|
||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||
size_k, size_n = q_w.shape
|
||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||
|
||||
orig_device = q_w.device
|
||||
q_w = q_w.contiguous().to(torch.int32)
|
||||
M, N = q_w.shape
|
||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||
for i in range(pack_factor):
|
||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||
|
||||
return q_packed
|
||||
|
||||
def w4a8_2_marlin_weight(w4a8_w):
|
||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||
full_w4a8_w = full_w4a8_w.T
|
||||
weight_perm = get_weight_perms()
|
||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||
return marlin_q_w
|
||||
|
||||
def w4a8_weight_repack_impl(input):
|
||||
if use_lightop:
|
||||
size_batch = input.shape[0]
|
||||
size_n = input.shape[1]
|
||||
size_k = input.shape[2] * 2
|
||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||
else:
|
||||
w_marlin_list = []
|
||||
for e in range(input.shape[0]):
|
||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||
w_marlin_list.append(w_marlin_in)
|
||||
output = torch.stack(w_marlin_list, dim=0)
|
||||
|
||||
return output
|
||||
@@ -22,8 +22,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import (
|
||||
apply_module_patch,
|
||||
@@ -40,8 +39,6 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from lmslim import quant_ops
|
||||
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -408,7 +405,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||
|
||||
output = quant_ops.triton_scaled_mm(
|
||||
output = int8_scaled_mm(
|
||||
x_q_2d,
|
||||
layer.weight,
|
||||
x_scale_2d,
|
||||
|
||||
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum()
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
|
||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
|
||||
# Detect stragger ranks in model loading
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
from ctypes import *
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
class Prof:
|
||||
def __init__(self):
|
||||
self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
|
||||
if self.use_roctx:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctxRangePushA.argtypes = [c_char_p]
|
||||
self.lib.roctxRangePushA.restype = c_int
|
||||
self.lib.roctxRangePop.restype = c_int
|
||||
self.tm = time.perf_counter()
|
||||
self.push_depth = {}
|
||||
|
||||
def StartTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_start()
|
||||
self.roc_tracer_flag = True
|
||||
|
||||
def StopTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_stop()
|
||||
self.roc_tracer_flag = False
|
||||
|
||||
def thread_depth_add(self, num):
|
||||
current_thread = threading.current_thread()
|
||||
thread_id = current_thread.ident
|
||||
if thread_id not in self.push_depth.keys():
|
||||
self.push_depth[thread_id] = 0
|
||||
if num < 0 and self.push_depth[thread_id] == 0:
|
||||
return False
|
||||
self.push_depth[thread_id] += num
|
||||
return True
|
||||
|
||||
def ProfRangePush(self, message):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
self.thread_depth_add(1)
|
||||
|
||||
def ProfRangePop(self):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
if not self.thread_depth_add(-1):
|
||||
return
|
||||
profile.lib.roctxRangePop()
|
||||
|
||||
def ProfRangeAutoPush(self, message):
|
||||
self.ProfRangePop()
|
||||
self.ProfRangePush(message)
|
||||
|
||||
|
||||
profile = Prof()
|
||||
@@ -93,7 +93,6 @@ QUANTIZATION_CHOICES = [
|
||||
"w4afp8",
|
||||
"mxfp4",
|
||||
"compressed-tensors", # for Ktransformers
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
|
||||
ATTENTION_BACKEND_CHOICES = [
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 64
|
||||
#define WARP_SIZE 32
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 64
|
||||
#define WARP_SIZE 32
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 64
|
||||
#define WARP_SIZE 32
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
#define WARP_SIZE_GGUF 64
|
||||
#define WARP_SIZE_GGUF 32
|
||||
#define K_SCALE_SIZE 12
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 64
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#define WARP_SIZE 64
|
||||
|
||||
Reference in New Issue
Block a user