Compare commits
2 Commits
d2fdeac22f
...
v0.5.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4dff7f5ef | ||
|
|
c0352f4aab |
@@ -839,12 +839,10 @@ class BenchmarkMetrics:
|
|||||||
mean_ttft_ms: float
|
mean_ttft_ms: float
|
||||||
median_ttft_ms: float
|
median_ttft_ms: float
|
||||||
std_ttft_ms: float
|
std_ttft_ms: float
|
||||||
p95_ttft_ms: float
|
|
||||||
p99_ttft_ms: float
|
p99_ttft_ms: float
|
||||||
mean_tpot_ms: float
|
mean_tpot_ms: float
|
||||||
median_tpot_ms: float
|
median_tpot_ms: float
|
||||||
std_tpot_ms: float
|
std_tpot_ms: float
|
||||||
p95_tpot_ms: float
|
|
||||||
p99_tpot_ms: float
|
p99_tpot_ms: float
|
||||||
mean_itl_ms: float
|
mean_itl_ms: float
|
||||||
median_itl_ms: float
|
median_itl_ms: float
|
||||||
@@ -1667,12 +1665,10 @@ def calculate_metrics(
|
|||||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||||
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
|
|
||||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||||
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
|
|
||||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
@@ -1978,12 +1974,6 @@ async def benchmark(
|
|||||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||||
|
|||||||
@@ -22,11 +22,9 @@ use_vllm_custom_allreduce = get_bool_env_var(
|
|||||||
|
|
||||||
if not is_hpu():
|
if not is_hpu():
|
||||||
# ROCm does not use vllm custom allreduce
|
# ROCm does not use vllm custom allreduce
|
||||||
# if use_vllm_custom_allreduce and not is_hip():
|
if use_vllm_custom_allreduce and not is_hip():
|
||||||
if use_vllm_custom_allreduce:
|
|
||||||
try:
|
try:
|
||||||
import vllm._C # noqa: F401
|
import vllm._C # noqa: F401
|
||||||
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import from vllm._C with %r", e)
|
logger.warning("Failed to import from vllm._C with %r", e)
|
||||||
else:
|
else:
|
||||||
@@ -36,11 +34,9 @@ if not is_hpu():
|
|||||||
logger.warning("Failed to import from custom_ar with %r", e)
|
logger.warning("Failed to import from custom_ar with %r", e)
|
||||||
|
|
||||||
|
|
||||||
# if not is_hip() and not is_npu():
|
if not is_hip() and not is_npu():
|
||||||
if not is_npu():
|
|
||||||
if use_vllm_custom_allreduce:
|
if use_vllm_custom_allreduce:
|
||||||
custom_op = torch.ops._C_custom_ar
|
custom_op = torch.ops._C_custom_ar
|
||||||
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
|
|
||||||
else:
|
else:
|
||||||
custom_op = sgl_kernel.allreduce
|
custom_op = sgl_kernel.allreduce
|
||||||
|
|
||||||
|
|||||||
@@ -614,8 +614,6 @@ class ModelConfig:
|
|||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
"quark",
|
"quark",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
"slimquant_w4a8_marlin",
|
|
||||||
"w8a8_int8",
|
|
||||||
]
|
]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8",
|
"fp8",
|
||||||
@@ -635,7 +633,6 @@ class ModelConfig:
|
|||||||
"qoq",
|
"qoq",
|
||||||
"w4afp8",
|
"w4afp8",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
"slimquant_w4a8_marlin",
|
|
||||||
]
|
]
|
||||||
compatible_quantization_methods = {
|
compatible_quantization_methods = {
|
||||||
"modelopt_fp4": ["modelopt"],
|
"modelopt_fp4": ["modelopt"],
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ _is_hip = is_hip()
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if ops.use_vllm_custom_allreduce and not _is_hip:
|
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||||
if ops.use_vllm_custom_allreduce:
|
|
||||||
# Use vLLM custom allreduce
|
# Use vLLM custom allreduce
|
||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -169,14 +169,6 @@ class RMSNorm(CustomOp):
|
|||||||
try:
|
try:
|
||||||
output = torch.empty_like(x)
|
output = torch.empty_like(x)
|
||||||
residual_out = torch.empty_like(x)
|
residual_out = torch.empty_like(x)
|
||||||
fused_add_rms_norm(
|
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return x, residual
|
|
||||||
except TypeError:
|
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
output,
|
output,
|
||||||
x,
|
x,
|
||||||
@@ -186,7 +178,14 @@ class RMSNorm(CustomOp):
|
|||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
)
|
)
|
||||||
return output, residual_out
|
return output, residual_out
|
||||||
|
except TypeError:
|
||||||
|
fused_add_rms_norm(
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return x, residual
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
|
|||||||
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Normal file → Executable file
0
python/sglang/srt/layers/moe/ep_moe/layer.py
Normal file → Executable file
@@ -14,10 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import (
|
from sglang.srt.layers.quantization.int8_kernel import (
|
||||||
per_token_group_quant_int8,
|
per_token_group_quant_int8,
|
||||||
# per_token_quant_int8,
|
per_token_quant_int8,
|
||||||
sglang_per_token_group_quant_int8,
|
sglang_per_token_group_quant_int8,
|
||||||
)
|
)
|
||||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
|||||||
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Normal file → Executable file
0
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Normal file → Executable file
@@ -57,7 +57,6 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
|||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
|
||||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||||
|
|
||||||
_is_mxfp_supported = mxfp_supported()
|
_is_mxfp_supported = mxfp_supported()
|
||||||
@@ -84,7 +83,6 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"w4afp8": W4AFp8Config,
|
"w4afp8": W4AFp8Config,
|
||||||
"petit_nvfp4": PetitNvFp4Config,
|
"petit_nvfp4": PetitNvFp4Config,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||||
import torch
|
import torch
|
||||||
from sglang.srt import _custom_ops as ops
|
from sglang.srt import _custom_ops as ops
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
@@ -218,9 +218,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
dispatch_output,
|
dispatch_output: StandardDispatchOutput,
|
||||||
) :
|
) -> CombineInput:
|
||||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
|
||||||
x = dispatch_output.hidden_states
|
x = dispatch_output.hidden_states
|
||||||
topk_output = dispatch_output.topk_output
|
topk_output = dispatch_output.topk_output
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
@@ -242,7 +241,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
use_int4_w4a8=True,
|
use_int4_w4a8=True,
|
||||||
per_channel_quant=True,
|
per_channel_quant=True,
|
||||||
activation=layer.moe_runner_config.activation,
|
activation=layer.moe_runner_config.activation,
|
||||||
# expert_map=layer.expert_map_gpu,
|
expert_map=layer.expert_map_gpu,
|
||||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||||
global_num_experts=layer.moe_runner_config.num_experts,
|
global_num_experts=layer.moe_runner_config.num_experts,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
try:
|
|
||||||
from lightop import awq_marlin_repack_w4a8
|
|
||||||
use_lightop = False
|
|
||||||
except Exception:
|
|
||||||
use_lightop = False
|
|
||||||
|
|
||||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
|
||||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
|
||||||
"""
|
|
||||||
if tensor_int8.dtype != torch.int8:
|
|
||||||
raise ValueError("Input tensor must be of type torch.int8")
|
|
||||||
|
|
||||||
N, K_half = tensor_int8.shape
|
|
||||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
|
||||||
high4 = tensor_uint8 & 0x0F
|
|
||||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
|
||||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
|
||||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
|
||||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
|
||||||
|
|
||||||
return unpacked
|
|
||||||
|
|
||||||
def get_weight_perms(interleave: bool=True):
|
|
||||||
perm = []
|
|
||||||
for i in range(64):
|
|
||||||
|
|
||||||
for col in range(4):
|
|
||||||
cur_col = (i % 16) * 4 + col
|
|
||||||
for row in range(8):
|
|
||||||
cur_row = (i // 16) * 8 + row
|
|
||||||
cur_idx = cur_row * 64 + cur_col
|
|
||||||
perm.append(cur_idx)
|
|
||||||
|
|
||||||
perm = np.array(perm)
|
|
||||||
if interleave:
|
|
||||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
|
||||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
|
||||||
|
|
||||||
perm = torch.from_numpy(perm)
|
|
||||||
|
|
||||||
return perm
|
|
||||||
|
|
||||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
|
||||||
size_k, size_n = q_w.shape
|
|
||||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
|
||||||
q_w = q_w.permute((0, 2, 1, 3))
|
|
||||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
|
||||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
|
||||||
|
|
||||||
orig_device = q_w.device
|
|
||||||
q_w = q_w.contiguous().to(torch.int32)
|
|
||||||
M, N = q_w.shape
|
|
||||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
|
||||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
|
||||||
for i in range(pack_factor):
|
|
||||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
|
||||||
|
|
||||||
return q_packed
|
|
||||||
|
|
||||||
def w4a8_2_marlin_weight(w4a8_w):
|
|
||||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
|
||||||
full_w4a8_w = full_w4a8_w.T
|
|
||||||
weight_perm = get_weight_perms()
|
|
||||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
|
||||||
return marlin_q_w
|
|
||||||
|
|
||||||
def w4a8_weight_repack_impl(input):
|
|
||||||
if use_lightop:
|
|
||||||
size_batch = input.shape[0]
|
|
||||||
size_n = input.shape[1]
|
|
||||||
size_k = input.shape[2] * 2
|
|
||||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
|
||||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
|
||||||
else:
|
|
||||||
w_marlin_list = []
|
|
||||||
for e in range(input.shape[0]):
|
|
||||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
|
||||||
w_marlin_list.append(w_marlin_in)
|
|
||||||
output = torch.stack(w_marlin_list, dim=0)
|
|
||||||
|
|
||||||
return output
|
|
||||||
@@ -22,8 +22,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||||
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
apply_module_patch,
|
apply_module_patch,
|
||||||
@@ -40,8 +39,6 @@ if TYPE_CHECKING:
|
|||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
from lmslim import quant_ops
|
|
||||||
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -408,7 +405,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||||
|
|
||||||
output = quant_ops.triton_scaled_mm(
|
output = int8_scaled_mm(
|
||||||
x_q_2d,
|
x_q_2d,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
x_scale_2d,
|
x_scale_2d,
|
||||||
|
|||||||
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.seq_lens_sum = self.seq_lens.sum()
|
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||||
self.output_ids = self.output_ids[keep_indices_device]
|
self.output_ids = self.output_ids[keep_indices_device]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
|||||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||||
|
|
||||||
# Detect stragger ranks in model loading
|
# Detect stragger ranks in model loading
|
||||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
|
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||||
|
|
||||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
from ctypes import *
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
|
|
||||||
class Prof:
|
|
||||||
def __init__(self):
|
|
||||||
self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
|
|
||||||
if self.use_roctx:
|
|
||||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
|
||||||
self.lib.roctxRangePushA.argtypes = [c_char_p]
|
|
||||||
self.lib.roctxRangePushA.restype = c_int
|
|
||||||
self.lib.roctxRangePop.restype = c_int
|
|
||||||
self.tm = time.perf_counter()
|
|
||||||
self.push_depth = {}
|
|
||||||
|
|
||||||
def StartTracer(self):
|
|
||||||
if self.use_roctx:
|
|
||||||
if self.lib is None:
|
|
||||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
|
||||||
self.lib.roctracer_start()
|
|
||||||
self.roc_tracer_flag = True
|
|
||||||
|
|
||||||
def StopTracer(self):
|
|
||||||
if self.use_roctx:
|
|
||||||
if self.lib is None:
|
|
||||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
|
||||||
self.lib.roctracer_stop()
|
|
||||||
self.roc_tracer_flag = False
|
|
||||||
|
|
||||||
def thread_depth_add(self, num):
|
|
||||||
current_thread = threading.current_thread()
|
|
||||||
thread_id = current_thread.ident
|
|
||||||
if thread_id not in self.push_depth.keys():
|
|
||||||
self.push_depth[thread_id] = 0
|
|
||||||
if num < 0 and self.push_depth[thread_id] == 0:
|
|
||||||
return False
|
|
||||||
self.push_depth[thread_id] += num
|
|
||||||
return True
|
|
||||||
|
|
||||||
def ProfRangePush(self, message):
|
|
||||||
if profile.use_roctx and self.roc_tracer_flag:
|
|
||||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
|
||||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
|
||||||
self.thread_depth_add(1)
|
|
||||||
|
|
||||||
def ProfRangePop(self):
|
|
||||||
if profile.use_roctx and self.roc_tracer_flag:
|
|
||||||
if not self.thread_depth_add(-1):
|
|
||||||
return
|
|
||||||
profile.lib.roctxRangePop()
|
|
||||||
|
|
||||||
def ProfRangeAutoPush(self, message):
|
|
||||||
self.ProfRangePop()
|
|
||||||
self.ProfRangePush(message)
|
|
||||||
|
|
||||||
|
|
||||||
profile = Prof()
|
|
||||||
@@ -93,7 +93,6 @@ QUANTIZATION_CHOICES = [
|
|||||||
"w4afp8",
|
"w4afp8",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
"compressed-tensors", # for Ktransformers
|
"compressed-tensors", # for Ktransformers
|
||||||
"slimquant_w4a8_marlin",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ATTENTION_BACKEND_CHOICES = [
|
ATTENTION_BACKEND_CHOICES = [
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
#define INTRIN_M 16
|
#define INTRIN_M 16
|
||||||
#define INTRIN_N 16
|
#define INTRIN_N 16
|
||||||
#define INTRIN_K 32
|
#define INTRIN_K 32
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 32
|
||||||
#define SMEM_PAD_A 0
|
#define SMEM_PAD_A 0
|
||||||
#define SMEM_PAD_B 0
|
#define SMEM_PAD_B 0
|
||||||
#define PACK_SIZE 16
|
#define PACK_SIZE 16
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
#define INTRIN_M 16
|
#define INTRIN_M 16
|
||||||
#define INTRIN_N 16
|
#define INTRIN_N 16
|
||||||
#define INTRIN_K 32
|
#define INTRIN_K 32
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 32
|
||||||
#define SMEM_PAD_A 0
|
#define SMEM_PAD_A 0
|
||||||
#define SMEM_PAD_B 0
|
#define SMEM_PAD_B 0
|
||||||
#define PACK_SIZE 16
|
#define PACK_SIZE 16
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 32
|
||||||
#include "pytorch_extension_utils.h"
|
#include "pytorch_extension_utils.h"
|
||||||
#else
|
#else
|
||||||
#include "pytorch_extension_utils_rocm.h"
|
#include "pytorch_extension_utils_rocm.h"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||||
#define QK_K 256
|
#define QK_K 256
|
||||||
#define K_QUANTS_PER_ITERATION 2
|
#define K_QUANTS_PER_ITERATION 2
|
||||||
#define WARP_SIZE_GGUF 64
|
#define WARP_SIZE_GGUF 32
|
||||||
#define K_SCALE_SIZE 12
|
#define K_SCALE_SIZE 12
|
||||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||||
|
|||||||
@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
|
|||||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 32
|
||||||
#else
|
#else
|
||||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||||
#define WARP_SIZE 64
|
#define WARP_SIZE 64
|
||||||
|
|||||||
Reference in New Issue
Block a user