Compare commits
13 Commits
v0.5.4
...
d2fdeac22f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2fdeac22f | ||
|
|
75cd34d172 | ||
|
|
8fc552638f | ||
|
|
eb4ba1c295 | ||
|
|
4b9b337b39 | ||
|
|
f6528b74be | ||
|
|
a5718531b7 | ||
|
|
c333f12547 | ||
|
|
f9a026ad2b | ||
|
|
b80ae5e9ff | ||
|
|
b091a7a5c9 | ||
|
|
143ec5f36c | ||
|
|
67510e0172 |
@@ -839,10 +839,12 @@ class BenchmarkMetrics:
|
|||||||
mean_ttft_ms: float
|
mean_ttft_ms: float
|
||||||
median_ttft_ms: float
|
median_ttft_ms: float
|
||||||
std_ttft_ms: float
|
std_ttft_ms: float
|
||||||
|
p95_ttft_ms: float
|
||||||
p99_ttft_ms: float
|
p99_ttft_ms: float
|
||||||
mean_tpot_ms: float
|
mean_tpot_ms: float
|
||||||
median_tpot_ms: float
|
median_tpot_ms: float
|
||||||
std_tpot_ms: float
|
std_tpot_ms: float
|
||||||
|
p95_tpot_ms: float
|
||||||
p99_tpot_ms: float
|
p99_tpot_ms: float
|
||||||
mean_itl_ms: float
|
mean_itl_ms: float
|
||||||
median_itl_ms: float
|
median_itl_ms: float
|
||||||
@@ -1665,10 +1667,12 @@ def calculate_metrics(
|
|||||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||||
|
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
|
||||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||||
|
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
|
||||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
@@ -1974,6 +1978,12 @@ async def benchmark(
|
|||||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
|
||||||
|
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
|
||||||
|
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
|
||||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||||
|
|||||||
@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
|
|||||||
|
|
||||||
if not is_hpu():
|
if not is_hpu():
|
||||||
# ROCm does not use vllm custom allreduce
|
# ROCm does not use vllm custom allreduce
|
||||||
if use_vllm_custom_allreduce and not is_hip():
|
# if use_vllm_custom_allreduce and not is_hip():
|
||||||
|
if use_vllm_custom_allreduce:
|
||||||
try:
|
try:
|
||||||
import vllm._C # noqa: F401
|
import vllm._C # noqa: F401
|
||||||
|
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import from vllm._C with %r", e)
|
logger.warning("Failed to import from vllm._C with %r", e)
|
||||||
else:
|
else:
|
||||||
@@ -34,9 +36,11 @@ if not is_hpu():
|
|||||||
logger.warning("Failed to import from custom_ar with %r", e)
|
logger.warning("Failed to import from custom_ar with %r", e)
|
||||||
|
|
||||||
|
|
||||||
if not is_hip() and not is_npu():
|
# if not is_hip() and not is_npu():
|
||||||
|
if not is_npu():
|
||||||
if use_vllm_custom_allreduce:
|
if use_vllm_custom_allreduce:
|
||||||
custom_op = torch.ops._C_custom_ar
|
custom_op = torch.ops._C_custom_ar
|
||||||
|
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
|
||||||
else:
|
else:
|
||||||
custom_op = sgl_kernel.allreduce
|
custom_op = sgl_kernel.allreduce
|
||||||
|
|
||||||
|
|||||||
@@ -614,6 +614,8 @@ class ModelConfig:
|
|||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
"quark",
|
"quark",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
|
"w8a8_int8",
|
||||||
]
|
]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8",
|
"fp8",
|
||||||
@@ -633,6 +635,7 @@ class ModelConfig:
|
|||||||
"qoq",
|
"qoq",
|
||||||
"w4afp8",
|
"w4afp8",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
]
|
]
|
||||||
compatible_quantization_methods = {
|
compatible_quantization_methods = {
|
||||||
"modelopt_fp4": ["modelopt"],
|
"modelopt_fp4": ["modelopt"],
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ _is_hip = is_hip()
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
# if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||||
|
if ops.use_vllm_custom_allreduce:
|
||||||
# Use vLLM custom allreduce
|
# Use vLLM custom allreduce
|
||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -169,6 +169,14 @@ class RMSNorm(CustomOp):
|
|||||||
try:
|
try:
|
||||||
output = torch.empty_like(x)
|
output = torch.empty_like(x)
|
||||||
residual_out = torch.empty_like(x)
|
residual_out = torch.empty_like(x)
|
||||||
|
fused_add_rms_norm(
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return x, residual
|
||||||
|
except TypeError:
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
output,
|
output,
|
||||||
x,
|
x,
|
||||||
@@ -178,14 +186,7 @@ class RMSNorm(CustomOp):
|
|||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
)
|
)
|
||||||
return output, residual_out
|
return output, residual_out
|
||||||
except TypeError:
|
|
||||||
fused_add_rms_norm(
|
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return x, residual
|
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
|
|||||||
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Executable file → Normal file
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Executable file → Normal file
@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import (
|
from sglang.srt.layers.quantization.int8_kernel import (
|
||||||
per_token_group_quant_int8,
|
per_token_group_quant_int8,
|
||||||
per_token_quant_int8,
|
# per_token_quant_int8,
|
||||||
sglang_per_token_group_quant_int8,
|
sglang_per_token_group_quant_int8,
|
||||||
)
|
)
|
||||||
|
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
|||||||
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Executable file → Normal file
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Executable file → Normal file
@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
|||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||||
|
|
||||||
_is_mxfp_supported = mxfp_supported()
|
_is_mxfp_supported = mxfp_supported()
|
||||||
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"w4afp8": W4AFp8Config,
|
"w4afp8": W4AFp8Config,
|
||||||
"petit_nvfp4": PetitNvFp4Config,
|
"petit_nvfp4": PetitNvFp4Config,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
|
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
|
||||||
import torch
|
import torch
|
||||||
from sglang.srt import _custom_ops as ops
|
from sglang.srt import _custom_ops as ops
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
dispatch_output: StandardDispatchOutput,
|
dispatch_output,
|
||||||
) -> CombineInput:
|
) :
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||||
x = dispatch_output.hidden_states
|
x = dispatch_output.hidden_states
|
||||||
topk_output = dispatch_output.topk_output
|
topk_output = dispatch_output.topk_output
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
use_int4_w4a8=True,
|
use_int4_w4a8=True,
|
||||||
per_channel_quant=True,
|
per_channel_quant=True,
|
||||||
activation=layer.moe_runner_config.activation,
|
activation=layer.moe_runner_config.activation,
|
||||||
expert_map=layer.expert_map_gpu,
|
# expert_map=layer.expert_map_gpu,
|
||||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||||
global_num_experts=layer.moe_runner_config.num_experts,
|
global_num_experts=layer.moe_runner_config.num_experts,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
|
|||||||
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lightop import awq_marlin_repack_w4a8
|
||||||
|
use_lightop = False
|
||||||
|
except Exception:
|
||||||
|
use_lightop = False
|
||||||
|
|
||||||
|
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||||
|
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||||
|
"""
|
||||||
|
if tensor_int8.dtype != torch.int8:
|
||||||
|
raise ValueError("Input tensor must be of type torch.int8")
|
||||||
|
|
||||||
|
N, K_half = tensor_int8.shape
|
||||||
|
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||||
|
high4 = tensor_uint8 & 0x0F
|
||||||
|
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||||
|
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||||
|
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||||
|
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||||
|
|
||||||
|
return unpacked
|
||||||
|
|
||||||
|
def get_weight_perms(interleave: bool=True):
|
||||||
|
perm = []
|
||||||
|
for i in range(64):
|
||||||
|
|
||||||
|
for col in range(4):
|
||||||
|
cur_col = (i % 16) * 4 + col
|
||||||
|
for row in range(8):
|
||||||
|
cur_row = (i // 16) * 8 + row
|
||||||
|
cur_idx = cur_row * 64 + cur_col
|
||||||
|
perm.append(cur_idx)
|
||||||
|
|
||||||
|
perm = np.array(perm)
|
||||||
|
if interleave:
|
||||||
|
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||||
|
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||||
|
|
||||||
|
perm = torch.from_numpy(perm)
|
||||||
|
|
||||||
|
return perm
|
||||||
|
|
||||||
|
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||||
|
size_k, size_n = q_w.shape
|
||||||
|
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||||
|
q_w = q_w.permute((0, 2, 1, 3))
|
||||||
|
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||||
|
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||||
|
|
||||||
|
orig_device = q_w.device
|
||||||
|
q_w = q_w.contiguous().to(torch.int32)
|
||||||
|
M, N = q_w.shape
|
||||||
|
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||||
|
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||||
|
for i in range(pack_factor):
|
||||||
|
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||||
|
|
||||||
|
return q_packed
|
||||||
|
|
||||||
|
def w4a8_2_marlin_weight(w4a8_w):
|
||||||
|
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||||
|
full_w4a8_w = full_w4a8_w.T
|
||||||
|
weight_perm = get_weight_perms()
|
||||||
|
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||||
|
return marlin_q_w
|
||||||
|
|
||||||
|
def w4a8_weight_repack_impl(input):
|
||||||
|
if use_lightop:
|
||||||
|
size_batch = input.shape[0]
|
||||||
|
size_n = input.shape[1]
|
||||||
|
size_k = input.shape[2] * 2
|
||||||
|
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||||
|
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||||
|
else:
|
||||||
|
w_marlin_list = []
|
||||||
|
for e in range(input.shape[0]):
|
||||||
|
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||||
|
w_marlin_list.append(w_marlin_in)
|
||||||
|
output = torch.stack(w_marlin_list, dim=0)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
apply_module_patch,
|
apply_module_patch,
|
||||||
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
|
|||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
|
from lmslim import quant_ops
|
||||||
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||||
|
|
||||||
output = int8_scaled_mm(
|
output = quant_ops.triton_scaled_mm(
|
||||||
x_q_2d,
|
x_q_2d,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
x_scale_2d,
|
x_scale_2d,
|
||||||
|
|||||||
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
self.seq_lens_sum = self.seq_lens.sum()
|
||||||
self.output_ids = self.output_ids[keep_indices_device]
|
self.output_ids = self.output_ids[keep_indices_device]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
|||||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||||
|
|
||||||
# Detect stragger ranks in model loading
|
# Detect stragger ranks in model loading
|
||||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
|
||||||
|
|
||||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||||
|
|||||||
58
python/sglang/srt/profile/prof.py
Normal file
58
python/sglang/srt/profile/prof.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from ctypes import *
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
|
class Prof:
|
||||||
|
def __init__(self):
|
||||||
|
self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
|
||||||
|
if self.use_roctx:
|
||||||
|
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||||
|
self.lib.roctxRangePushA.argtypes = [c_char_p]
|
||||||
|
self.lib.roctxRangePushA.restype = c_int
|
||||||
|
self.lib.roctxRangePop.restype = c_int
|
||||||
|
self.tm = time.perf_counter()
|
||||||
|
self.push_depth = {}
|
||||||
|
|
||||||
|
def StartTracer(self):
|
||||||
|
if self.use_roctx:
|
||||||
|
if self.lib is None:
|
||||||
|
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||||
|
self.lib.roctracer_start()
|
||||||
|
self.roc_tracer_flag = True
|
||||||
|
|
||||||
|
def StopTracer(self):
|
||||||
|
if self.use_roctx:
|
||||||
|
if self.lib is None:
|
||||||
|
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||||
|
self.lib.roctracer_stop()
|
||||||
|
self.roc_tracer_flag = False
|
||||||
|
|
||||||
|
def thread_depth_add(self, num):
|
||||||
|
current_thread = threading.current_thread()
|
||||||
|
thread_id = current_thread.ident
|
||||||
|
if thread_id not in self.push_depth.keys():
|
||||||
|
self.push_depth[thread_id] = 0
|
||||||
|
if num < 0 and self.push_depth[thread_id] == 0:
|
||||||
|
return False
|
||||||
|
self.push_depth[thread_id] += num
|
||||||
|
return True
|
||||||
|
|
||||||
|
def ProfRangePush(self, message):
|
||||||
|
if profile.use_roctx and self.roc_tracer_flag:
|
||||||
|
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||||
|
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||||
|
self.thread_depth_add(1)
|
||||||
|
|
||||||
|
def ProfRangePop(self):
|
||||||
|
if profile.use_roctx and self.roc_tracer_flag:
|
||||||
|
if not self.thread_depth_add(-1):
|
||||||
|
return
|
||||||
|
profile.lib.roctxRangePop()
|
||||||
|
|
||||||
|
def ProfRangeAutoPush(self, message):
|
||||||
|
self.ProfRangePop()
|
||||||
|
self.ProfRangePush(message)
|
||||||
|
|
||||||
|
|
||||||
|
profile = Prof()
|
||||||
@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
|
|||||||
"w4afp8",
|
"w4afp8",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
"compressed-tensors", # for Ktransformers
|
"compressed-tensors", # for Ktransformers
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
]
|
]
|
||||||
|
|
||||||
ATTENTION_BACKEND_CHOICES = [
|
ATTENTION_BACKEND_CHOICES = [
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
#define INTRIN_M 16
|
#define INTRIN_M 16
|
||||||
#define INTRIN_N 16
|
#define INTRIN_N 16
|
||||||
#define INTRIN_K 32
|
#define INTRIN_K 32
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 64
|
||||||
#define SMEM_PAD_A 0
|
#define SMEM_PAD_A 0
|
||||||
#define SMEM_PAD_B 0
|
#define SMEM_PAD_B 0
|
||||||
#define PACK_SIZE 16
|
#define PACK_SIZE 16
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
#define INTRIN_M 16
|
#define INTRIN_M 16
|
||||||
#define INTRIN_N 16
|
#define INTRIN_N 16
|
||||||
#define INTRIN_K 32
|
#define INTRIN_K 32
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 64
|
||||||
#define SMEM_PAD_A 0
|
#define SMEM_PAD_A 0
|
||||||
#define SMEM_PAD_B 0
|
#define SMEM_PAD_B 0
|
||||||
#define PACK_SIZE 16
|
#define PACK_SIZE 16
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 64
|
||||||
#include "pytorch_extension_utils.h"
|
#include "pytorch_extension_utils.h"
|
||||||
#else
|
#else
|
||||||
#include "pytorch_extension_utils_rocm.h"
|
#include "pytorch_extension_utils_rocm.h"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||||
#define QK_K 256
|
#define QK_K 256
|
||||||
#define K_QUANTS_PER_ITERATION 2
|
#define K_QUANTS_PER_ITERATION 2
|
||||||
#define WARP_SIZE_GGUF 32
|
#define WARP_SIZE_GGUF 64
|
||||||
#define K_SCALE_SIZE 12
|
#define K_SCALE_SIZE 12
|
||||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||||
|
|||||||
@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
|
|||||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 64
|
||||||
#else
|
#else
|
||||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 64
|
||||||
|
|||||||
Reference in New Issue
Block a user