Compare commits
13 Commits
v0.5.4
...
d2fdeac22f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2fdeac22f | ||
|
|
75cd34d172 | ||
|
|
8fc552638f | ||
|
|
eb4ba1c295 | ||
|
|
4b9b337b39 | ||
|
|
f6528b74be | ||
|
|
a5718531b7 | ||
|
|
c333f12547 | ||
|
|
f9a026ad2b | ||
|
|
b80ae5e9ff | ||
|
|
b091a7a5c9 | ||
|
|
143ec5f36c | ||
|
|
67510e0172 |
@@ -839,10 +839,12 @@ class BenchmarkMetrics:
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p95_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p95_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
@@ -1665,10 +1667,12 @@ def calculate_metrics(
|
||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
@@ -1974,6 +1978,12 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
|
||||
@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
|
||||
|
||||
if not is_hpu():
|
||||
# ROCm does not use vllm custom allreduce
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
# if use_vllm_custom_allreduce and not is_hip():
|
||||
if use_vllm_custom_allreduce:
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
else:
|
||||
@@ -34,9 +36,11 @@ if not is_hpu():
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
|
||||
if not is_hip() and not is_npu():
|
||||
# if not is_hip() and not is_npu():
|
||||
if not is_npu():
|
||||
if use_vllm_custom_allreduce:
|
||||
custom_op = torch.ops._C_custom_ar
|
||||
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
|
||||
else:
|
||||
custom_op = sgl_kernel.allreduce
|
||||
|
||||
|
||||
@@ -614,6 +614,8 @@ class ModelConfig:
|
||||
"petit_nvfp4",
|
||||
"quark",
|
||||
"mxfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
"w8a8_int8",
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
@@ -633,6 +635,7 @@ class ModelConfig:
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
"petit_nvfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
|
||||
@@ -27,7 +27,8 @@ _is_hip = is_hip()
|
||||
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
# if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
# Use vLLM custom allreduce
|
||||
ops.meta_size()
|
||||
else:
|
||||
|
||||
@@ -169,6 +169,14 @@ class RMSNorm(CustomOp):
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -178,14 +186,7 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Executable file → Normal file
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Executable file → Normal file
@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
per_token_group_quant_int8,
|
||||
per_token_quant_int8,
|
||||
# per_token_quant_int8,
|
||||
sglang_per_token_group_quant_int8,
|
||||
)
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
|
||||
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Executable file → Normal file
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Executable file → Normal file
@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||
|
||||
_is_mxfp_supported = mxfp_supported()
|
||||
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"w4afp8": W4AFp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
|
||||
import torch
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
dispatch_output,
|
||||
) :
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=layer.moe_runner_config.activation,
|
||||
expert_map=layer.expert_map_gpu,
|
||||
# expert_map=layer.expert_map_gpu,
|
||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||
global_num_experts=layer.moe_runner_config.num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
|
||||
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from lightop import awq_marlin_repack_w4a8
|
||||
use_lightop = False
|
||||
except Exception:
|
||||
use_lightop = False
|
||||
|
||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||
|
||||
Args:
|
||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||
"""
|
||||
if tensor_int8.dtype != torch.int8:
|
||||
raise ValueError("Input tensor must be of type torch.int8")
|
||||
|
||||
N, K_half = tensor_int8.shape
|
||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||
high4 = tensor_uint8 & 0x0F
|
||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||
|
||||
return unpacked
|
||||
|
||||
def get_weight_perms(interleave: bool=True):
|
||||
perm = []
|
||||
for i in range(64):
|
||||
|
||||
for col in range(4):
|
||||
cur_col = (i % 16) * 4 + col
|
||||
for row in range(8):
|
||||
cur_row = (i // 16) * 8 + row
|
||||
cur_idx = cur_row * 64 + cur_col
|
||||
perm.append(cur_idx)
|
||||
|
||||
perm = np.array(perm)
|
||||
if interleave:
|
||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||
|
||||
perm = torch.from_numpy(perm)
|
||||
|
||||
return perm
|
||||
|
||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||
size_k, size_n = q_w.shape
|
||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||
|
||||
orig_device = q_w.device
|
||||
q_w = q_w.contiguous().to(torch.int32)
|
||||
M, N = q_w.shape
|
||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||
for i in range(pack_factor):
|
||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||
|
||||
return q_packed
|
||||
|
||||
def w4a8_2_marlin_weight(w4a8_w):
|
||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||
full_w4a8_w = full_w4a8_w.T
|
||||
weight_perm = get_weight_perms()
|
||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||
return marlin_q_w
|
||||
|
||||
def w4a8_weight_repack_impl(input):
|
||||
if use_lightop:
|
||||
size_batch = input.shape[0]
|
||||
size_n = input.shape[1]
|
||||
size_k = input.shape[2] * 2
|
||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||
else:
|
||||
w_marlin_list = []
|
||||
for e in range(input.shape[0]):
|
||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||
w_marlin_list.append(w_marlin_in)
|
||||
output = torch.stack(w_marlin_list, dim=0)
|
||||
|
||||
return output
|
||||
@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import (
|
||||
apply_module_patch,
|
||||
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from lmslim import quant_ops
|
||||
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||
|
||||
output = int8_scaled_mm(
|
||||
output = quant_ops.triton_scaled_mm(
|
||||
x_q_2d,
|
||||
layer.weight,
|
||||
x_scale_2d,
|
||||
|
||||
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
self.seq_lens_sum = self.seq_lens.sum()
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
|
||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
|
||||
# Detect stragger ranks in model loading
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
|
||||
|
||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||
|
||||
58
python/sglang/srt/profile/prof.py
Normal file
58
python/sglang/srt/profile/prof.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from ctypes import *
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
class Prof:
|
||||
def __init__(self):
|
||||
self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
|
||||
if self.use_roctx:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctxRangePushA.argtypes = [c_char_p]
|
||||
self.lib.roctxRangePushA.restype = c_int
|
||||
self.lib.roctxRangePop.restype = c_int
|
||||
self.tm = time.perf_counter()
|
||||
self.push_depth = {}
|
||||
|
||||
def StartTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_start()
|
||||
self.roc_tracer_flag = True
|
||||
|
||||
def StopTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_stop()
|
||||
self.roc_tracer_flag = False
|
||||
|
||||
def thread_depth_add(self, num):
|
||||
current_thread = threading.current_thread()
|
||||
thread_id = current_thread.ident
|
||||
if thread_id not in self.push_depth.keys():
|
||||
self.push_depth[thread_id] = 0
|
||||
if num < 0 and self.push_depth[thread_id] == 0:
|
||||
return False
|
||||
self.push_depth[thread_id] += num
|
||||
return True
|
||||
|
||||
def ProfRangePush(self, message):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
self.thread_depth_add(1)
|
||||
|
||||
def ProfRangePop(self):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
if not self.thread_depth_add(-1):
|
||||
return
|
||||
profile.lib.roctxRangePop()
|
||||
|
||||
def ProfRangeAutoPush(self, message):
|
||||
self.ProfRangePop()
|
||||
self.ProfRangePush(message)
|
||||
|
||||
|
||||
profile = Prof()
|
||||
@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
|
||||
"w4afp8",
|
||||
"mxfp4",
|
||||
"compressed-tensors", # for Ktransformers
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
|
||||
ATTENTION_BACKEND_CHOICES = [
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
#define WARP_SIZE_GGUF 32
|
||||
#define WARP_SIZE_GGUF 64
|
||||
#define K_SCALE_SIZE 12
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#define WARP_SIZE 64
|
||||
|
||||
Reference in New Issue
Block a user