Compare commits
13 Commits
v0.5.4
...
d2fdeac22f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2fdeac22f | ||
|
|
75cd34d172 | ||
|
|
8fc552638f | ||
|
|
eb4ba1c295 | ||
|
|
4b9b337b39 | ||
|
|
f6528b74be | ||
|
|
a5718531b7 | ||
|
|
c333f12547 | ||
|
|
f9a026ad2b | ||
|
|
b80ae5e9ff | ||
|
|
b091a7a5c9 | ||
|
|
143ec5f36c | ||
|
|
67510e0172 |
@@ -839,10 +839,12 @@ class BenchmarkMetrics:
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p95_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p95_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
@@ -1665,10 +1667,12 @@ def calculate_metrics(
|
||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
@@ -1974,6 +1978,12 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
|
||||
@@ -5,6 +5,15 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
||||
try:
|
||||
from lmslim import quant_ops
|
||||
from lmslim import quant_tools
|
||||
except Exception:
|
||||
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
|
||||
try:
|
||||
import lightop
|
||||
except Exception:
|
||||
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = get_bool_env_var(
|
||||
@@ -13,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
|
||||
|
||||
if not is_hpu():
|
||||
# ROCm does not use vllm custom allreduce
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
# if use_vllm_custom_allreduce and not is_hip():
|
||||
if use_vllm_custom_allreduce:
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
else:
|
||||
@@ -25,9 +36,11 @@ if not is_hpu():
|
||||
logger.warning("Failed to import from custom_ar with %r", e)
|
||||
|
||||
|
||||
if not is_hip() and not is_npu():
|
||||
# if not is_hip() and not is_npu():
|
||||
if not is_npu():
|
||||
if use_vllm_custom_allreduce:
|
||||
custom_op = torch.ops._C_custom_ar
|
||||
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
|
||||
else:
|
||||
custom_op = sgl_kernel.allreduce
|
||||
|
||||
@@ -175,3 +188,25 @@ def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||
|
||||
def triton_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
best_config:Optional[list] = None) -> torch.Tensor:
|
||||
|
||||
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config)
|
||||
|
||||
def triton_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.float16,
|
||||
device: str = "cuda:0",
|
||||
best_config:Optional[list] = None,
|
||||
repeat:Optional[int] = 2):
|
||||
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)
|
||||
@@ -614,6 +614,8 @@ class ModelConfig:
|
||||
"petit_nvfp4",
|
||||
"quark",
|
||||
"mxfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
"w8a8_int8",
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
@@ -633,6 +635,7 @@ class ModelConfig:
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
"petit_nvfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
|
||||
@@ -27,7 +27,8 @@ _is_hip = is_hip()
|
||||
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
# if ops.use_vllm_custom_allreduce and not _is_hip:
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
# Use vLLM custom allreduce
|
||||
ops.meta_size()
|
||||
else:
|
||||
|
||||
@@ -169,6 +169,14 @@ class RMSNorm(CustomOp):
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -178,14 +186,7 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
import torch
|
||||
|
||||
from sglang.srt import single_batch_overlap
|
||||
@@ -54,7 +55,286 @@ if _use_aiter:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepEPMoE(FusedMoE):
|
||||
# TODO(kaixih@nvidia): ideally we should merge this logic into
|
||||
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
|
||||
@torch.compile
|
||||
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
|
||||
temp = x.to(torch.float32).view(torch.int32)
|
||||
exp = torch.bitwise_right_shift(temp, 23)
|
||||
mant = torch.bitwise_and(temp, 0x7FFFFF)
|
||||
is_ru = torch.logical_and(
|
||||
torch.logical_and((mant > 0), (exp != 0xFE)),
|
||||
~torch.logical_and((exp == 0), (mant <= 0x400000)),
|
||||
)
|
||||
exp = torch.where(is_ru, exp + 1, exp)
|
||||
new_x = exp.to(torch.uint8).view(torch.int)
|
||||
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
|
||||
|
||||
class EPMoE(FusedMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_id: int,
|
||||
num_fused_shared_experts: int = 0,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_clamp_limit: Optional[float] = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
layer_id=layer_id,
|
||||
top_k=top_k,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
activation=activation,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.block_shape = (
|
||||
self.quant_method.quant_config.weight_block_size
|
||||
if self.use_block_quant
|
||||
else None
|
||||
)
|
||||
self.use_fp8_w8a8 = True
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.activation_scheme = quant_config.activation_scheme
|
||||
self.use_w4a8_marlin = False
|
||||
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.block_shape = (
|
||||
self.quant_method.quant_config.weight_block_size
|
||||
if self.use_block_quant
|
||||
else None
|
||||
)
|
||||
self.use_fp8_w8a8 = False
|
||||
self.activation_scheme = None
|
||||
self.use_w4a8_marlin = True
|
||||
else:
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
self.block_shape = None
|
||||
self.activation_scheme = None
|
||||
self.use_w4a8_marlin = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
return self.forward_deepgemm(hidden_states, topk_output)
|
||||
else:
|
||||
return super().forward(hidden_states, topk_output)
|
||||
|
||||
def forward_deepgemm(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||
)
|
||||
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
if not self.use_block_quant:
|
||||
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
|
||||
scale_block_size = 128
|
||||
w13_weight_scale_n = 2 * (
|
||||
(self.intermediate_size + scale_block_size - 1) // scale_block_size
|
||||
)
|
||||
w13_weight_scale_k = (
|
||||
hidden_states_shape[-1] + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w13_weight_scale = (
|
||||
self.w13_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w13_weight_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w13_weight_scale_k, dim=2)
|
||||
)
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
w13_weight_scale,
|
||||
)
|
||||
w2_weight_scale_n = (
|
||||
hidden_states_shape[-1] + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w2_weight_scale_k = (
|
||||
self.intermediate_size + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w2_weight_scale = (
|
||||
self.w2_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w2_weight_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w2_weight_scale_k, dim=2)
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
w2_weight_scale,
|
||||
)
|
||||
|
||||
# PreReorder
|
||||
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
|
||||
moe_ep_deepgemm_preprocess(
|
||||
topk_ids,
|
||||
self.num_experts,
|
||||
hidden_states,
|
||||
self.top_k,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.block_shape,
|
||||
)
|
||||
)
|
||||
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
b, s_mn, s_k = gateup_input_scale.shape
|
||||
assert (
|
||||
s_mn % 4 == 0 and s_k % 4 == 0
|
||||
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
|
||||
|
||||
# GroupGemm-0
|
||||
gateup_input_fp8 = (
|
||||
gateup_input,
|
||||
(
|
||||
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||
gateup_input_scale
|
||||
)
|
||||
),
|
||||
)
|
||||
num_groups, m, k = gateup_input_fp8[0].size()
|
||||
n = self.w13_weight.size(1)
|
||||
gateup_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
gateup_input_fp8,
|
||||
self.w13_weight_fp8,
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
del gateup_input
|
||||
del gateup_input_fp8
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=self.fp8_dtype,
|
||||
)
|
||||
scale_block_size = 128
|
||||
down_input_scale = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
silu_and_mul_masked_post_quant_fwd(
|
||||
gateup_output,
|
||||
down_input,
|
||||
down_input_scale,
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
n = self.w2_weight.size(1)
|
||||
down_input_fp8 = (
|
||||
down_input,
|
||||
(
|
||||
down_input_scale
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
||||
),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
down_input_fp8,
|
||||
self.w2_weight_fp8,
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
del down_input
|
||||
del down_input_fp8
|
||||
|
||||
# PostReorder
|
||||
output = torch.empty(
|
||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
||||
)
|
||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.top_k,
|
||||
hidden_states_shape[1],
|
||||
m_max * self.start_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= self.moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
|
||||
class DeepEPMoE(EPMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
||||
Mooncake EP shares the same class, as they expose the same interface.
|
||||
@@ -106,11 +386,28 @@ class DeepEPMoE(FusedMoE):
|
||||
|
||||
self.deepep_mode = get_deepep_mode()
|
||||
|
||||
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||
# NPU supports low_latency deepep without deepgemm
|
||||
assert (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
# TODO: move to the beginning of the file
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
deepep_mode=self.deepep_mode,
|
||||
async_finish=True, # TODO
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
# if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||
# # NPU supports low_latency deepep without deepgemm
|
||||
# assert (
|
||||
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
if _use_aiter:
|
||||
# expert_mask is of size (self.num_local_experts + 1),
|
||||
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
||||
@@ -124,23 +421,23 @@ class DeepEPMoE(FusedMoE):
|
||||
)
|
||||
# the last one is invalid rank_id
|
||||
self.expert_mask[:-1] = 1
|
||||
elif not _is_npu:
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant or self.use_w4afp8
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
(
|
||||
self.w2_weight_scale_inv
|
||||
if self.use_block_quant or self.use_w4afp8
|
||||
else self.w2_weight_scale
|
||||
),
|
||||
)
|
||||
# elif not _is_npu:
|
||||
# self.w13_weight_fp8 = (
|
||||
# self.w13_weight,
|
||||
# (
|
||||
# self.w13_weight_scale_inv
|
||||
# if self.use_block_quant
|
||||
# else self.w13_weight_scale
|
||||
# ),
|
||||
# )
|
||||
# self.w2_weight_fp8 = (
|
||||
# self.w2_weight,
|
||||
# (
|
||||
# self.w2_weight_scale_inv
|
||||
# if self.use_block_quant
|
||||
# else self.w2_weight_scale
|
||||
# ),
|
||||
# )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -187,10 +484,15 @@ class DeepEPMoE(FusedMoE):
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
if self.use_w4afp8:
|
||||
return self.forward_cutlass_w4afp8(dispatch_output)
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif self.use_w4a8_marlin:
|
||||
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dispatch output is not supported"
|
||||
)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
if (
|
||||
get_moe_runner_backend().is_flashinfer_cutedsl()
|
||||
@@ -255,6 +557,34 @@ class DeepEPMoE(FusedMoE):
|
||||
expert_mask=self.expert_mask,
|
||||
)
|
||||
|
||||
def forward_deepgemm_w4a8_marlin_contiguous(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||
dispatch_output
|
||||
)
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
# if num_recv_tokens_per_expert is None:
|
||||
return hidden_states_int8.bfloat16()
|
||||
# expert_output = self.quant_method.apply_ep(
|
||||
# layer=self,
|
||||
# x=dispatch_output,
|
||||
# topk_weights=topk_weights,
|
||||
# topk_ids=topk_idx,
|
||||
# global_num_experts=self.global_num_experts,
|
||||
# expert_map=self.expert_map,
|
||||
# activation=self.activation,
|
||||
# apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
# use_nn_moe=self.use_nn_moe,
|
||||
# num_local_tokens=dispatch_recv_num_token,
|
||||
# config_select_bs=hidden_states.shape[0],
|
||||
# scales=dispatch_scales if self.use_int8_dispatch else None
|
||||
# # routed_scaling_factor=self.routed_scaling_factor,
|
||||
# )
|
||||
# return expert_output
|
||||
|
||||
def forward_deepgemm_contiguous(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
|
||||
@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
per_token_group_quant_int8,
|
||||
per_token_quant_int8,
|
||||
# per_token_quant_int8,
|
||||
sglang_per_token_group_quant_int8,
|
||||
)
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
|
||||
@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
else:
|
||||
raise NotImplementedError() # triton runner was supported but it's temporarily disabled
|
||||
|
||||
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
# else:
|
||||
# if hidden_states.shape[0] > 0:
|
||||
# num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
# output = torch.empty(
|
||||
# (num_tokens, hidden_states.shape[1]),
|
||||
# device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype,
|
||||
# )
|
||||
# deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
# hidden_states,
|
||||
# output,
|
||||
# self.src2dst,
|
||||
# topk_idx,
|
||||
# topk_weights,
|
||||
# self.router_topk,
|
||||
# hidden_states.shape[1],
|
||||
# BLOCK_SIZE=512,
|
||||
# )
|
||||
# else:
|
||||
# output = torch.zeros(
|
||||
# (0, hidden_states.shape[1]),
|
||||
# device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype,
|
||||
# )
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return output, previous_event
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||
|
||||
_is_mxfp_supported = mxfp_supported()
|
||||
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"w4afp8": W4AFp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
415
python/sglang/srt/layers/quantization/slimquant_w4a8.py
Normal file
415
python/sglang/srt/layers/quantization/slimquant_w4a8.py
Normal file
@@ -0,0 +1,415 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.linear import set_weight_attrs
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from torch.nn.parameter import Parameter
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
_ColumnvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from lmslim.layers.gemm.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
per_token_quant_int8)
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from vllm.utils import W8a8GetCacheJSON
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
|
||||
import os
|
||||
|
||||
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
"""
|
||||
Parameter class for linear layer weights. Uses both column and
|
||||
row parallelism.
|
||||
"""
|
||||
pass
|
||||
|
||||
W8A8_TRITONJSON=W8a8GetCacheJSON()
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
scales= scale_a* scale_b.T
|
||||
gemmout= torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||
output = (scales *gemmout).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8Config(QuantizationConfig):
|
||||
"""Config class for W8A8 Int8 Quantization.
|
||||
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "slimquant_w4a8"
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return SlimQuantW4A8Int8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return SlimQuantW4A8Int8MoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
|
||||
self.quantization_config = quantization_config
|
||||
self.tritonsingleton= W8a8GetCacheJSON()
|
||||
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
n=layer.weight.shape[0]
|
||||
k=layer.weight.shape[1]
|
||||
|
||||
if self.w8a8_strategy==1:
|
||||
if {n,k} not in self.tritonsingleton.weight_shapes:
|
||||
self.tritonsingleton.weight_shapes.append({n,k})
|
||||
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
|
||||
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
|
||||
|
||||
if configs_dict:
|
||||
self.tritonsingleton.triton_json_dict.update(configs_dict)
|
||||
|
||||
for key, value in configs_dict.items():
|
||||
m=int(key.split('_')[0])
|
||||
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
|
||||
else:
|
||||
weight_data=layer.weight.data
|
||||
_weight=weight_data.T.contiguous().reshape(n,-1)
|
||||
layer.weight.data=_weight
|
||||
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
input_quant_args: Optional[list[torch.Tensor]] = None,
|
||||
silu_quant_args: Optional[list[torch.Tensor]] = None
|
||||
):
|
||||
# if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
|
||||
# assert len(input_quant_args) == 2
|
||||
# x_q, x_scale = input_quant_args
|
||||
# elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
|
||||
# x_q, x_scale = silu_quant_args
|
||||
# else:
|
||||
x_q, x_scale = per_token_quant_int8(x)
|
||||
|
||||
if self.w8a8_strategy==1:
|
||||
m=x_q.shape[0]
|
||||
k=x_q.shape[1]
|
||||
n=layer.weight.shape[1]
|
||||
|
||||
if len(W8A8_TRITONJSON.triton_json_dict)==0:
|
||||
best_config=None
|
||||
|
||||
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
|
||||
if m<=16:
|
||||
m_=m
|
||||
elif m<=64:
|
||||
m_= (m + 3) & -4 #取值到最近的4的倍数
|
||||
elif m<=160:
|
||||
m_=(m + 7) & -8
|
||||
|
||||
elif m<200: #256
|
||||
m_=160
|
||||
elif m<480: #512
|
||||
m_=256
|
||||
elif m<960: #1024
|
||||
m_=512
|
||||
elif m<2048:
|
||||
m_=1024
|
||||
elif m<4096:
|
||||
m_=2048
|
||||
elif m<6000:
|
||||
m_=4096
|
||||
else:
|
||||
m_=8192
|
||||
|
||||
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
|
||||
|
||||
else:
|
||||
best_config=None
|
||||
|
||||
#if best_config==None:
|
||||
# print("m:{},n:{},k:{}".format(m,n,k))
|
||||
# print("config not found!")
|
||||
|
||||
return ops.triton_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,best_config=best_config)
|
||||
elif self.w8a8_strategy==2:
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
else:
|
||||
return ops.rocblas_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MoEMethod:
|
||||
"""MoE method for W4A8INT8.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
self.quant_config = quant_config
|
||||
self.tritonsingleton= W8a8GetCacheJSON()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_input_scale = None
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = None
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
E=layer.w13_weight.shape[0]
|
||||
N1=layer.w13_weight.shape[1]
|
||||
N2=layer.w2_weight.shape[1]
|
||||
K=N1//2
|
||||
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
|
||||
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
|
||||
|
||||
TOPK= self.tritonsingleton.topk
|
||||
|
||||
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
|
||||
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
|
||||
|
||||
#warmup
|
||||
if configs_dict:
|
||||
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
use_nn_moe: Optional[bool] = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
use_fused_gate: Optional[bool] = False,
|
||||
**_
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_fused_gate=use_fused_gate
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=activation,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
use_nn_moe=use_nn_moe,
|
||||
)
|
||||
319
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
Normal file
319
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
Normal file
@@ -0,0 +1,319 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
|
||||
import torch
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from torch.nn.parameter import Parameter
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
|
||||
from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
|
||||
try:
|
||||
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
|
||||
except Exception:
|
||||
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
|
||||
|
||||
|
||||
class MarlinMoeWorkspace:
|
||||
"""
|
||||
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
|
||||
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
|
||||
"""
|
||||
_instances = {}
|
||||
def __new__(cls, device):
|
||||
if device not in cls._instances:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instances[device] = instance
|
||||
return cls._instances[device]
|
||||
|
||||
def __init__(self, device):
|
||||
if self._initialized:
|
||||
return
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
self.workspace = torch.zeros(
|
||||
500, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
self.global_reduce_buffer = torch.zeros(
|
||||
sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
def get_buffers(self):
|
||||
return self.workspace, self.global_reduce_buffer
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
scales= scale_a* scale_b.T
|
||||
gemmout= torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||
output = (scales *gemmout).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
|
||||
"""Config class for W4A8 Int8 Quantization.
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "slimquant_w4a8_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
|
||||
return cls()
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
|
||||
and user_quant == "slimquant_w4a8_marlin":
|
||||
return cls.get_name()
|
||||
return None
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return SlimQuantW4A8Int8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return SlimQuantW4A8Int8MarlinMoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
"""MoE method for W4A8INT8 Marlin.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = intermediate_size_per_partition
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_input_scale = None
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = None
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
|
||||
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
dispatch_output,
|
||||
) :
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||
output = fused_experts_impl_w4a8_marlin(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
workspace=workspace,
|
||||
global_reduce_buffer=global_reduce_buffer,
|
||||
inplace=True,
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=layer.moe_runner_config.activation,
|
||||
# expert_map=layer.expert_map_gpu,
|
||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||
global_num_experts=layer.moe_runner_config.num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
use_nn_moe=False,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
# def _apply(
|
||||
# self,
|
||||
# layer: torch.nn.Module,
|
||||
# x: torch.Tensor,
|
||||
# router_logits: torch.Tensor,
|
||||
# top_k: int,
|
||||
# #renormalize: bool,
|
||||
# #use_grouped_topk: bool = False,
|
||||
# topk_group: Optional[int] = None,
|
||||
# num_expert_group: Optional[int] = None,
|
||||
# global_num_experts: int = -1,
|
||||
# expert_map: Optional[torch.Tensor] = None,
|
||||
# custom_routing_function: Optional[Callable] = None,
|
||||
# scoring_func: str = "softmax",
|
||||
# e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
# apply_router_weight_on_input: bool = False,
|
||||
# activation: str = "silu",
|
||||
# enable_eplb: bool = False,
|
||||
# use_nn_moe: Optional[bool] = False,
|
||||
# routed_scaling_factor: Optional[float] = None,
|
||||
# use_fused_gate: Optional[bool] = False,
|
||||
# **_
|
||||
# ) -> torch.Tensor:
|
||||
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
# if enable_eplb:
|
||||
# raise NotImplementedError(
|
||||
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
|
||||
# # Expert selection
|
||||
# topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
# hidden_states=x,
|
||||
# router_logits=router_logits,
|
||||
# #use_grouped_topk=use_grouped_topk,
|
||||
# top_k=top_k,
|
||||
# #renormalize=renormalize,
|
||||
# topk_group=topk_group,
|
||||
# num_expert_group=num_expert_group,
|
||||
# custom_routing_function=custom_routing_function,
|
||||
# scoring_func=scoring_func,
|
||||
# e_score_correction_bias=e_score_correction_bias,
|
||||
# routed_scaling_factor=routed_scaling_factor,
|
||||
# use_fused_gate=use_fused_gate
|
||||
# )
|
||||
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||
# return fused_experts_impl_w4a8_marlin(
|
||||
# x,
|
||||
# layer.w13_weight,
|
||||
# layer.w2_weight,
|
||||
# topk_weights=topk_weights,
|
||||
# topk_ids=topk_ids,
|
||||
# workspace=workspace,
|
||||
# global_reduce_buffer=global_reduce_buffer,
|
||||
# inplace=True,
|
||||
# use_int4_w4a8=True,
|
||||
# per_channel_quant=True,
|
||||
# activation=activation,
|
||||
# expert_map=expert_map,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# global_num_experts=global_num_experts,
|
||||
# w1_scale=(layer.w13_weight_scale),
|
||||
# w2_scale=(layer.w2_weight_scale),
|
||||
# a1_scale=layer.w13_input_scale,
|
||||
# a2_scale=layer.w2_input_scale,
|
||||
# use_nn_moe=use_nn_moe,
|
||||
# )
|
||||
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from lightop import awq_marlin_repack_w4a8
|
||||
use_lightop = False
|
||||
except Exception:
|
||||
use_lightop = False
|
||||
|
||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||
|
||||
Args:
|
||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||
"""
|
||||
if tensor_int8.dtype != torch.int8:
|
||||
raise ValueError("Input tensor must be of type torch.int8")
|
||||
|
||||
N, K_half = tensor_int8.shape
|
||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||
high4 = tensor_uint8 & 0x0F
|
||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||
|
||||
return unpacked
|
||||
|
||||
def get_weight_perms(interleave: bool=True):
|
||||
perm = []
|
||||
for i in range(64):
|
||||
|
||||
for col in range(4):
|
||||
cur_col = (i % 16) * 4 + col
|
||||
for row in range(8):
|
||||
cur_row = (i // 16) * 8 + row
|
||||
cur_idx = cur_row * 64 + cur_col
|
||||
perm.append(cur_idx)
|
||||
|
||||
perm = np.array(perm)
|
||||
if interleave:
|
||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||
|
||||
perm = torch.from_numpy(perm)
|
||||
|
||||
return perm
|
||||
|
||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||
size_k, size_n = q_w.shape
|
||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||
|
||||
orig_device = q_w.device
|
||||
q_w = q_w.contiguous().to(torch.int32)
|
||||
M, N = q_w.shape
|
||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||
for i in range(pack_factor):
|
||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||
|
||||
return q_packed
|
||||
|
||||
def w4a8_2_marlin_weight(w4a8_w):
|
||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||
full_w4a8_w = full_w4a8_w.T
|
||||
weight_perm = get_weight_perms()
|
||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||
return marlin_q_w
|
||||
|
||||
def w4a8_weight_repack_impl(input):
|
||||
if use_lightop:
|
||||
size_batch = input.shape[0]
|
||||
size_n = input.shape[1]
|
||||
size_k = input.shape[2] * 2
|
||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||
else:
|
||||
w_marlin_list = []
|
||||
for e in range(input.shape[0]):
|
||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||
w_marlin_list.append(w_marlin_in)
|
||||
output = torch.stack(w_marlin_list, dim=0)
|
||||
|
||||
return output
|
||||
@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import (
|
||||
apply_module_patch,
|
||||
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from lmslim import quant_ops
|
||||
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||
|
||||
output = int8_scaled_mm(
|
||||
output = quant_ops.triton_scaled_mm(
|
||||
x_q_2d,
|
||||
layer.weight,
|
||||
x_scale_2d,
|
||||
|
||||
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
self.seq_lens_sum = self.seq_lens.sum()
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
|
||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
|
||||
# Detect stragger ranks in model loading
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
|
||||
|
||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||
|
||||
58
python/sglang/srt/profile/prof.py
Normal file
58
python/sglang/srt/profile/prof.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from ctypes import *
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
class Prof:
|
||||
def __init__(self):
|
||||
self.use_roctx = os.getenv('SGLANG_HIP_PROF') is not None
|
||||
if self.use_roctx:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctxRangePushA.argtypes = [c_char_p]
|
||||
self.lib.roctxRangePushA.restype = c_int
|
||||
self.lib.roctxRangePop.restype = c_int
|
||||
self.tm = time.perf_counter()
|
||||
self.push_depth = {}
|
||||
|
||||
def StartTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_start()
|
||||
self.roc_tracer_flag = True
|
||||
|
||||
def StopTracer(self):
|
||||
if self.use_roctx:
|
||||
if self.lib is None:
|
||||
self.lib = cdll.LoadLibrary("libroctracer64.so")
|
||||
self.lib.roctracer_stop()
|
||||
self.roc_tracer_flag = False
|
||||
|
||||
def thread_depth_add(self, num):
|
||||
current_thread = threading.current_thread()
|
||||
thread_id = current_thread.ident
|
||||
if thread_id not in self.push_depth.keys():
|
||||
self.push_depth[thread_id] = 0
|
||||
if num < 0 and self.push_depth[thread_id] == 0:
|
||||
return False
|
||||
self.push_depth[thread_id] += num
|
||||
return True
|
||||
|
||||
def ProfRangePush(self, message):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
profile.lib.roctxRangePushA(message.encode('utf-8'))
|
||||
self.thread_depth_add(1)
|
||||
|
||||
def ProfRangePop(self):
|
||||
if profile.use_roctx and self.roc_tracer_flag:
|
||||
if not self.thread_depth_add(-1):
|
||||
return
|
||||
profile.lib.roctxRangePop()
|
||||
|
||||
def ProfRangeAutoPush(self, message):
|
||||
self.ProfRangePop()
|
||||
self.ProfRangePush(message)
|
||||
|
||||
|
||||
profile = Prof()
|
||||
@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
|
||||
"w4afp8",
|
||||
"mxfp4",
|
||||
"compressed-tensors", # for Ktransformers
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
|
||||
ATTENTION_BACKEND_CHOICES = [
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
|
||||
#define QK_K 256
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
#define WARP_SIZE_GGUF 32
|
||||
#define WARP_SIZE_GGUF 64
|
||||
#define K_SCALE_SIZE 12
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#define WARP_SIZE 64
|
||||
|
||||
Reference in New Issue
Block a user