Add initial support for gpt-oss (#8824)
@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
||||
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.qo_indptr = torch.zeros(
|
||||
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices,
|
||||
bs,
|
||||
self.device,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
)
|
||||
window_num_kv_splits = torch.empty(
|
||||
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
|
||||
mask_indptr = None
|
||||
max_extend_len = None
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
# TODO: Support sliding window in spec inference
|
||||
bs = len(forward_batch.req_pool_indices)
|
||||
qo_indptr = torch.arange(
|
||||
0,
|
||||
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
||||
update_sliding_window_buffer(
|
||||
self.window_kv_indptr,
|
||||
self.req_to_token,
|
||||
self.sliding_window_size,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.req_pool_indices,
|
||||
bs,
|
||||
self.device,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
)
|
||||
|
||||
custom_mask = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (
|
||||
forward_batch.seq_lens + self.num_draft_tokens
|
||||
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices,
|
||||
bs,
|
||||
self.device,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
qo_indptr = self.qo_indptr
|
||||
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
|
||||
):
|
||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
|
||||
self.window_kv_indptr,
|
||||
window_kv_indices,
|
||||
self.req_to_token,
|
||||
self.sliding_window_size,
|
||||
seq_lens[:bs],
|
||||
req_pool_indices,
|
||||
bs,
|
||||
window_kv_indptr, window_kv_indices, _ = (
|
||||
update_sliding_window_buffer_cuda_graph(
|
||||
self.window_kv_indptr,
|
||||
window_kv_indices,
|
||||
self.req_to_token,
|
||||
self.sliding_window_size,
|
||||
seq_lens[:bs],
|
||||
req_pool_indices,
|
||||
bs,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||
window_kv_indptr, window_kv_indices, _ = (
|
||||
update_sliding_window_buffer_cuda_graph(
|
||||
self.window_kv_indptr,
|
||||
window_kv_indices,
|
||||
self.req_to_token,
|
||||
self.sliding_window_size,
|
||||
seq_lens,
|
||||
req_pool_indices,
|
||||
bs,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
)
|
||||
|
||||
custom_mask = self.cuda_graph_custom_mask
|
||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
):
|
||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||
self.window_kv_indptr,
|
||||
window_kv_indices,
|
||||
self.req_to_token,
|
||||
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_lens[:bs],
|
||||
req_pool_indices[:bs],
|
||||
bs,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
self.get_num_kv_splits(
|
||||
window_num_kv_splits[:num_token], window_kv_lens[:bs]
|
||||
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||
self.window_kv_indptr,
|
||||
window_kv_indices,
|
||||
self.req_to_token,
|
||||
self.sliding_window_size,
|
||||
seq_lens,
|
||||
req_pool_indices,
|
||||
bs,
|
||||
self.token_to_kv_pool_allocator,
|
||||
)
|
||||
custom_mask = self.cuda_graph_custom_mask
|
||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
sk=None,
|
||||
):
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.forward_metadata.max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
sliding_window_size,
|
||||
sliding_window_size=sliding_window_size,
|
||||
sk=sk,
|
||||
)
|
||||
return o
|
||||
|
||||
@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
sk=None,
|
||||
):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.max_kv_splits,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
sk=sk,
|
||||
)
|
||||
return o
|
||||
|
||||
@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
|
||||
req_pool_indices,
|
||||
bs,
|
||||
device,
|
||||
token_to_kv_pool_allocator=None,
|
||||
):
|
||||
window_kv_lens = torch.minimum(
|
||||
seq_lens,
|
||||
torch.tensor(sliding_window_size + 1),
|
||||
torch.tensor(sliding_window_size),
|
||||
)
|
||||
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
||||
window_kv_indptr = window_kv_indptr[: bs + 1]
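A small worked example (not part of the commit, with made-up lengths) of the indptr construction above: per-request window lengths are capped at sliding_window_size + 1 and turned into prefix-sum offsets.

import torch

sliding_window_size = 4
seq_lens = torch.tensor([2, 9, 6])
# Cap each request at the window (+1 so the current token is included).
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
window_kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.long)
window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
print(window_kv_lens)    # tensor([2, 5, 5])
print(window_kv_indptr)  # tensor([ 0,  2,  7, 12])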
|
||||
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
|
||||
window_kv_indices,
|
||||
req_to_token.stride(0),
|
||||
)
|
||||
# full to swa index mapping
|
||||
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
|
||||
kv_last_index = window_kv_indptr[-1]
|
||||
window_kv_indices[:kv_last_index] = (
|
||||
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
window_kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
return window_kv_indptr, window_kv_indices, window_kv_lens
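A minimal sketch (not part of the commit; the allocator class is hypothetical, only the translate_loc_from_full_to_swa name comes from the diff) of what the full-to-SWA hook above is assumed to do: map slot indices from the full-attention KV pool onto slots of the smaller sliding-window pool.

import torch

class ToySWAAllocator:
    """Hypothetical stand-in for an SWA-aware token_to_kv_pool_allocator."""

    def __init__(self, full_to_swa: torch.Tensor):
        # full_to_swa[i] = slot in the SWA pool backing full-pool slot i
        self.full_to_swa = full_to_swa

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self.full_to_swa[loc]

alloc = ToySWAAllocator(torch.tensor([0, 3, 1, 2, 4]))
window_kv_indices = torch.tensor([1, 2, 4])  # indices into the full KV pool
print(alloc.translate_loc_from_full_to_swa(window_kv_indices))  # tensor([3, 1, 2])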
|
||||
|
||||
|
||||
@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
|
||||
seq_lens,
|
||||
req_pool_indices,
|
||||
bs,
|
||||
token_to_kv_pool_allocator=None,
|
||||
):
|
||||
window_kv_lens = torch.minimum(
|
||||
seq_lens,
|
||||
torch.tensor(sliding_window_size + 1),
|
||||
torch.tensor(sliding_window_size),
|
||||
)
|
||||
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
||||
window_kv_indptr = window_kv_indptr[: bs + 1]
|
||||
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
|
||||
window_kv_indices,
|
||||
req_to_token.stride(0),
|
||||
)
|
||||
return window_kv_indptr, window_kv_lens
|
||||
# full to swa index mapping
|
||||
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
|
||||
kv_last_index = window_kv_indptr[-1]
|
||||
window_kv_indices[:kv_last_index] = (
|
||||
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
window_kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
return window_kv_indptr, window_kv_indices, window_kv_lens
|
||||
|
||||
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
|
||||
O,
|
||||
kv_indptr,
|
||||
num_kv_splits,
|
||||
sk_ptr,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
|
||||
MIN_BLOCK_KV: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
HAS_SK: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
|
||||
e_sum = e_sum * old_scale + exp_logic
|
||||
e_max = n_e_max
|
||||
|
||||
if HAS_SK:
|
||||
cur_sk = tl.load(sk_ptr + cur_head)
|
||||
e_sum += tl.exp(cur_sk - e_max)
|
||||
|
||||
tl.store(
|
||||
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||
acc / e_sum,
|
||||
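A plain-PyTorch sketch (not part of the commit) of the math the HAS_SK branch implements: the per-head sink logit sk only adds probability mass to the softmax denominator, which is equivalent to an extra attention slot that contributes no value.

import torch

def softmax_with_sink(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    # logits: [num_heads, seq_len], sink: [num_heads]
    e_max = logits.max(dim=-1).values                       # running max, as in the kernel
    e_sum = torch.exp(logits - e_max[:, None]).sum(dim=-1)
    e_sum = e_sum + torch.exp(sink - e_max)                 # the HAS_SK line
    return torch.exp(logits - e_max[:, None]) / e_sum[:, None]

logits, sink = torch.randn(4, 7), torch.randn(4)
# Reference: append the sink as one more logit, softmax, then drop its column.
ref = torch.softmax(torch.cat([logits, sink[:, None]], dim=-1), dim=-1)[:, :-1]
assert torch.allclose(softmax_with_sink(logits, sink), ref, atol=1e-6)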
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
|
||||
kv_indptr,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sk=None,
|
||||
):
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
MAX_KV_SPLITS = max_kv_splits
|
||||
HAS_SK = sk is not None
|
||||
|
||||
extra_kargs = {}
|
||||
if _is_hip:
|
||||
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
|
||||
o,
|
||||
kv_indptr,
|
||||
num_kv_splits,
|
||||
sk,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
|
||||
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
Lv=Lv,
|
||||
HAS_SK=HAS_SK,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
**extra_kargs,
|
||||
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sk=None,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
|
||||
kv_indptr,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sk,
|
||||
)
|
||||
|
||||
|
||||
@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sk=None,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
|
||||
kv_indptr,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sk,
|
||||
)
|
||||
|
||||
|
||||
@@ -687,6 +701,7 @@ def decode_attention_fwd(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sk=None,
|
||||
):
|
||||
assert max_kv_splits == attn_logits.shape[2]
|
||||
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
||||
@@ -709,6 +724,7 @@ def decode_attention_fwd(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
sk=sk,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
@@ -725,4 +741,5 @@ def decode_attention_fwd(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
sk=sk,
|
||||
)
|
||||
|
||||
@@ -51,6 +51,7 @@ def _fwd_kernel(
|
||||
kv_indices,
|
||||
mask_ptr,
|
||||
mask_indptr,
|
||||
sk_ptr,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
stride_qbs,
|
||||
@@ -78,6 +79,7 @@ def _fwd_kernel(
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
||||
STORE_TRANSPOSE: tl.constexpr,
|
||||
HAS_SK: tl.constexpr,
|
||||
):
|
||||
cur_seq = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
@@ -178,13 +180,17 @@ def _fwd_kernel(
|
||||
final_mask &= custom_mask
|
||||
if SLIDING_WINDOW_SIZE > 0:
|
||||
# Add mask where q_id <= kv_id + sliding_window_size
|
||||
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
|
||||
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
|
||||
)
|
||||
# q_id = prefix_len + cur_m, kv_id = cur_n
|
||||
window_mask = (
|
||||
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
|
||||
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
|
||||
final_mask &= window_mask
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
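A short sketch (not part of the commit) of why row_max is clamped to -1e20 above: once the sliding-window mask is applied, an entire row of qk can be -inf, and exp(-inf - (-inf)) evaluates to NaN.

import torch

qk = torch.full((1, 4), float("-inf"))                    # fully masked row
naive = torch.exp(qk - qk.max(dim=1, keepdim=True).values)
print(naive)                                               # tensor([[nan, nan, nan, nan]])

row_max = qk.max(dim=1).values
row_max_fixed = torch.where(row_max == float("-inf"),
                            torch.full_like(row_max, -1e20), row_max)
safe = torch.exp(qk - row_max_fixed[:, None])
print(safe)                                                # tensor([[0., 0., 0., 0.]])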
|
||||
@@ -242,6 +248,7 @@ def _fwd_kernel(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||
if USE_CUSTOM_MASK:
|
||||
custom_mask = tl.load(
|
||||
mask_ptr
|
||||
@@ -254,18 +261,30 @@ def _fwd_kernel(
|
||||
other=0,
|
||||
)
|
||||
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
||||
qk = tl.where(custom_mask, qk, float("-inf"))
|
||||
final_mask &= custom_mask
|
||||
elif IS_CAUSAL:
|
||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||
start_n + offs_n[None, :]
|
||||
)
|
||||
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||
qk = tl.where(mask_causual, qk, float("-inf"))
|
||||
final_mask &= mask_causual
|
||||
else:
|
||||
mask_non_causal = mask_m[:, None] & mask_n[None, :]
|
||||
qk = tl.where(mask_non_causal, qk, float("-inf"))
|
||||
final_mask &= mask_non_causal
|
||||
|
||||
if SLIDING_WINDOW_SIZE > 0:
|
||||
# Add mask where q_id <= kv_id + sliding_window_size
|
||||
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
|
||||
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
|
||||
)
|
||||
final_mask &= window_mask
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
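A tiny worked example (not part of the commit, hypothetical positions and window size) of the window predicate used above: a key kv_id stays visible from query q_id iff q_id <= kv_id + SLIDING_WINDOW_SIZE, so together with the causal mask each query sees itself plus at most SLIDING_WINDOW_SIZE earlier keys.

import torch

SLIDING_WINDOW_SIZE = 3
q_id = torch.arange(6)[:, None]
kv_id = torch.arange(6)[None, :]
visible = (q_id >= kv_id) & (q_id <= kv_id + SLIDING_WINDOW_SIZE)
print(visible.int())
# Row 5: keys 2..5 are visible, i.e. the current token plus 3 previous ones,
# consistent with get_attention_sliding_window_size() = config.sliding_window - 1.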
|
||||
@@ -283,6 +302,10 @@ def _fwd_kernel(
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
if HAS_SK:
|
||||
cur_sk = tl.load(sk_ptr + cur_head)
|
||||
deno += tl.exp(cur_sk - e_max)
|
||||
|
||||
offs_o = (
|
||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_obs
|
||||
@@ -321,6 +344,7 @@ def extend_attention_fwd(
|
||||
logit_cap=0.0,
|
||||
skip_prefix_custom_mask=True,
|
||||
sliding_window_size=-1,
|
||||
sk=None,
|
||||
):
|
||||
"""
|
||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
@@ -386,6 +410,8 @@ def extend_attention_fwd(
|
||||
# Skip custom mask for prefix part
|
||||
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
||||
|
||||
HAS_SK = sk is not None
|
||||
|
||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||
num_stages = 1
|
||||
|
||||
@@ -405,6 +431,7 @@ def extend_attention_fwd(
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_indptr,
|
||||
sk,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
@@ -431,6 +458,7 @@ def extend_attention_fwd(
|
||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||
IS_CAUSAL=is_causal,
|
||||
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||
HAS_SK=HAS_SK,
|
||||
STORE_TRANSPOSE=_is_hip,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
|
||||
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
|
||||
else self.weight_loader
|
||||
),
|
||||
)
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError(
|
||||
"When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results"
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
||||
|
||||
@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
use_weight_loader_fused: bool = False,
|
||||
with_bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
|
||||
self.expert_map_cpu = None
|
||||
self.expert_map_gpu = None
|
||||
|
||||
# For activation
|
||||
self.activation_alpha = activation_alpha
|
||||
self.swiglu_limit = swiglu_limit
|
||||
|
||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||
enable_flashinfer_cutlass_moe = False
|
||||
@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels
|
||||
self.use_triton_kernels, with_bias=with_bias
|
||||
)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
weight_loader=(
|
||||
self.weight_loader
|
||||
if not use_weight_loader_fused
|
||||
else self.weight_loader_fused
|
||||
),
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
is_bias: bool = False,
|
||||
):
|
||||
# Load grouped weight scales for group quantization
|
||||
# or model weights
|
||||
@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
is_bias=is_bias,
|
||||
)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
elif shard_id in ("w1", "w3", "w13"):
|
||||
self._load_w13(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
is_bias=is_bias,
|
||||
)
|
||||
|
||||
def _load_per_channel_weight_scale(
|
||||
@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
is_bias: bool = False,
|
||||
):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
assert shard_id in {"w1", "w3", "w13"}
|
||||
|
||||
if is_bias:
|
||||
# if this weight is a bias, the last dimension must be the sharded dimension
|
||||
shard_dim = -1
|
||||
|
||||
if shard_id in {"w1", "w3"}:
|
||||
# non-fused version
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
elif shard_id in {"w13"}:
|
||||
# fused version
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
# trtllm cutlass kernel assumes differently
|
||||
assert shard_id in ("w1", "w3")
|
||||
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
|
||||
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
|
||||
start = shard_size
|
||||
@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
if self.use_triton_kernels:
|
||||
if not is_bias and self.use_triton_kernels:
|
||||
# do not transpose for bias
|
||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
is_bias: bool = False,
|
||||
):
|
||||
"""Load w2 weights for down projection.
|
||||
|
||||
@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if is_bias:
|
||||
# this expert_data is a bias, not weight,
|
||||
# for w2_bias in TP, it does not need to be sharded
|
||||
shard_size = expert_data.shape[-1]
|
||||
else:
|
||||
# this parameter is a weight matrix
|
||||
# for w2 in TP, it shards the input_features, i.e., shard_dim=2
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
|
||||
if _is_cpu:
|
||||
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
|
||||
not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
if not is_bias and not self.use_presharded_weights:
|
||||
if self.use_triton_kernels:
|
||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
|
||||
@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
return
|
||||
|
||||
def weight_loader_fused(
|
||||
self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
) -> None:
|
||||
tp_rank = self.moe_tp_rank
|
||||
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO: check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
loaded_weight = (
|
||||
loaded_weight.t().contiguous()
|
||||
if (
|
||||
self.quant_method.__class__.__name__
|
||||
== "CompressedTensorsWNA16MoEMethod"
|
||||
)
|
||||
else loaded_weight
|
||||
)
|
||||
|
||||
if shard_id not in ("w13", "w2"):
|
||||
raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
|
||||
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
|
||||
SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
|
||||
|
||||
expert_data = param.data
|
||||
is_bias = expert_data.dim() == 2
|
||||
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
# should be whatever dimension intermediate_size is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
|
||||
if self.use_triton_kernels:
|
||||
is_transposed = True
|
||||
shard_dim = (
|
||||
SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if not is_transposed
|
||||
else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
|
||||
)
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
is_bias=is_bias,
|
||||
)
|
||||
return
|
||||
else:
|
||||
logging.warning(
|
||||
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
# Matrix multiply.
|
||||
with use_symmetric_memory(get_tp_group()) as sm:
|
||||
kwargs = {}
|
||||
if self.activation_alpha is not None:
|
||||
kwargs["activation_alpha"] = self.activation_alpha
|
||||
if self.swiglu_limit is not None:
|
||||
kwargs["swiglu_limit"] = self.swiglu_limit
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
|
||||
== "ModelOptNvFp4FusedMoEMethod"
|
||||
else {}
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
sm.tag(final_hidden_states)
|
||||
|
||||
@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
|
||||
]
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping_fused(
|
||||
cls,
|
||||
ckpt_gate_up_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_gate_up_proj_bias_name: str,
|
||||
ckpt_down_proj_bias_name: str,
|
||||
):
|
||||
return [
|
||||
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
|
||||
(
|
||||
"experts.w13_weight_bias",
|
||||
f"experts.{ckpt_gate_up_proj_bias_name}",
|
||||
"w13",
|
||||
),
|
||||
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
|
||||
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
|
||||
]
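A usage sketch (not part of the commit; the checkpoint tensor names are assumptions for illustration) of the fused mapping defined above:

mapping = FusedMoE.make_expert_params_mapping_fused(
    ckpt_gate_up_proj_name="gate_up_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
    ckpt_down_proj_bias_name="down_proj_bias",
)
# [("experts.w13_weight",      "experts.gate_up_proj",      "w13"),
#  ("experts.w13_weight_bias", "experts.gate_up_proj_bias", "w13"),
#  ("experts.w2_weight",       "experts.down_proj",         "w2"),
#  ("experts.w2_weight_bias",  "experts.down_proj_bias",    "w2")]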
|
||||
|
||||
@classmethod
|
||||
def make_expert_input_scale_params_mapping(
|
||||
cls,
|
||||
|
||||
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
from triton_kernels.matmul_ogs import matmul_ogs
|
||||
from triton_kernels.matmul_ogs import (
|
||||
FlexCtx,
|
||||
FnSpecs,
|
||||
FusedActivation,
|
||||
PrecisionConfig,
|
||||
matmul_ogs,
|
||||
)
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||
|
||||
from sglang.srt.utils import direct_register_custom_op
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
def quantize(w, dtype, dev, **opt):
|
||||
if dtype == "bf16":
|
||||
return w.to(torch.bfloat16), InFlexData()
|
||||
elif dtype == "fp8":
|
||||
wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
|
||||
return (
|
||||
wq,
|
||||
InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
|
||||
MicroscalingCtx(),
|
||||
)
|
||||
else:
|
||||
assert dtype == "mx4", f"{dtype=}"
|
||||
swizzle_mx_scale = opt["swizzle_mx_scale"]
|
||||
swizzle_axis = 2 if swizzle_mx_scale else None
|
||||
w = w.to(torch.bfloat16)
|
||||
w, mx_scales, weight_scale_shape = downcast_to_mxfp(
|
||||
w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
|
||||
)
|
||||
return (
|
||||
w,
|
||||
InFlexData(),
|
||||
MicroscalingCtx(
|
||||
weight_scale=mx_scales,
|
||||
swizzle_mx=swizzle_mx_scale,
|
||||
actual_weight_scale_shape=weight_scale_shape,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def triton_kernel_moe_with_bias_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
b2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert topk_output.format.is_triton_kernel()
|
||||
routing_data, gather_idx, scatter_idx = topk_output
|
||||
|
||||
return triton_kernel_fused_experts_with_bias(
|
||||
hidden_states,
|
||||
w1,
|
||||
b1,
|
||||
w2,
|
||||
b2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
|
||||
def triton_kernel_fused_experts_with_bias(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
b2: torch.Tensor,
|
||||
routing_data: RoutingData,
|
||||
gather_indx: GatherIndx,
|
||||
scatter_indx: ScatterIndx,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
|
||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||
assert expert_map == None, "expert_map is not supported"
|
||||
assert w1_scale == None, "w1_scale is not supported"
|
||||
assert w2_scale == None, "w2_scale is not supported"
|
||||
assert a1_scale == None, "a1_scale is not supported"
|
||||
assert a2_scale == None, "a2_scale is not supported"
|
||||
assert block_shape == None, "block_shape is not supported"
|
||||
|
||||
# type check
|
||||
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
|
||||
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
|
||||
|
||||
# Shape check
|
||||
assert hidden_states.ndim == 2, "hidden_states must be 2D"
|
||||
assert (
|
||||
hidden_states.shape[-1] == w1.shape[-2]
|
||||
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
|
||||
assert (
|
||||
w2.shape[-1] == w1.shape[1]
|
||||
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||
|
||||
# feature check
|
||||
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
||||
|
||||
E, _, _ = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
device = "cuda"
|
||||
optg = dict()
|
||||
w1, w1_flex = quantize(w1, "bf16", device, **optg)
|
||||
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
|
||||
|
||||
w2, w2_flex = quantize(w2, "bf16", device, **optg)
|
||||
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
|
||||
(activation_alpha, swiglu_limit),
|
||||
2,
|
||||
)
|
||||
|
||||
intermediate_cache = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
b1,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=w1_pcg,
|
||||
gammas=None,
|
||||
fused_activation=act,
|
||||
)
|
||||
|
||||
return matmul_ogs(
|
||||
intermediate_cache,
|
||||
w2,
|
||||
b2,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=w2_pcg,
|
||||
gammas=routing_data.gate_scal,
|
||||
)
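For reference, a plain-PyTorch sketch (not part of the commit, and an assumption about what triton_kernels' swiglu_fn with ("alpha", "limit") computes) of the clamped SwiGLU variant that the FusedActivation above fuses into the first matmul; the input split between gate and linear halves is also assumed.

import torch

def clamped_swiglu(x_glu: torch.Tensor, x_linear: torch.Tensor,
                   alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # alpha matches the default hidden_act_alpha used in gpt_oss.py below;
    # limit is an illustrative value standing in for config.swiglu_limit.
    x_glu = x_glu.clamp(max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    return (x_linear + 1) * x_glu * torch.sigmoid(alpha * x_glu)

gate, up = torch.randn(2, 8), torch.randn(2, 8)
print(clamped_swiglu(gate, up).shape)  # torch.Size([2, 8])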
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
|
||||
try:
|
||||
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
align,
|
||||
ceil_div,
|
||||
get_bool_env_var,
|
||||
get_cuda_version,
|
||||
get_device_capability,
|
||||
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
|
||||
return output.to(dtype=input_2d.dtype).view(*output_shape)
|
||||
|
||||
|
||||
def dequant_mxfp4(
|
||||
w_block: torch.Tensor,
|
||||
w_scale: torch.Tensor,
|
||||
out_dtype,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
:param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
|
||||
:param w_scale: (batch, n, k), uint8
|
||||
:return: (batch, n, k * 32), in out_dtype
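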
|
||||
"""
|
||||
|
||||
assert w_block.dtype == torch.uint8
|
||||
assert w_scale.dtype == torch.uint8
|
||||
|
||||
batch, n, k, pack_dim = w_block.shape
|
||||
batch_, n_, k_ = w_scale.shape
|
||||
assert pack_dim == 16
|
||||
assert batch == batch_
|
||||
assert n == n_
|
||||
assert k == k_
|
||||
|
||||
out_raw = MXFP4QuantizeUtil.dequantize(
|
||||
quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
|
||||
)
|
||||
return out_raw.reshape(batch, n, k * 32)
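A shape-only usage sketch (not part of the commit, made-up sizes, assuming the imports and definitions above) for dequant_mxfp4: each uint8 in w_block packs two FP4 values and each scale covers a 32-element block, so the packed last axis of 16 bytes expands to 32 dequantized values.

w_block = torch.zeros(8, 4, 6, 16, dtype=torch.uint8)    # (batch, n, k, 16)
w_scale = torch.full((8, 4, 6), 127, dtype=torch.uint8)  # biased E8M0 exponent of 0
w = dequant_mxfp4(w_block, w_scale, torch.bfloat16)
print(w.shape)   # torch.Size([8, 4, 192]) == (batch, n, k * 32)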
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
python/sglang/srt/layers/quantization/mxfp4_tensor.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
|
||||
class MXFP4QuantizeUtil:
|
||||
E2M1_max = 6.0
|
||||
|
||||
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
|
||||
|
||||
@classmethod
|
||||
def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
|
||||
"""Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
|
||||
Args:
|
||||
input (torch.Tensor): The input tensor to be quantized.
|
||||
block_size (int | None): The block size for quantization.
|
||||
"""
|
||||
|
||||
def cast_fp4(x):
|
||||
sign = torch.sign(x)
|
||||
sign_bit = (2 - sign) // 2
|
||||
ord_ = torch.sum(
|
||||
(x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
|
||||
)
|
||||
fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
|
||||
return fp4_val
|
||||
|
||||
def fuse_uint4_to_uint8(x):
|
||||
# If the last dimension is odd, pad with zeros
|
||||
# If this behavior is not desired, please modify the code accordingly
|
||||
left_side = x[..., 0::2] # Even indices (0, 2, 4...)
|
||||
right_side = x[..., 1::2] # Odd indices (1, 3, 5...)
|
||||
new_data = (
|
||||
right_side.clone() << 4
|
||||
) # Put odd indices (higher addresses) in high bits
|
||||
new_data[
|
||||
..., : left_side.shape[-1]
|
||||
] += left_side # Put even indices in low bits
|
||||
return new_data
|
||||
|
||||
if block_size is None:
|
||||
block_size = 32
|
||||
|
||||
original_shape = input.shape
|
||||
original_dtype = input.dtype
|
||||
input = input.view(-1, block_size)
|
||||
# get scales
|
||||
input_amax = input.abs().max(dim=-1, keepdim=True).values
|
||||
descale = input_amax / cls.E2M1_max
|
||||
min_value = torch.tensor(-127.0, device=descale.device)
|
||||
e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
|
||||
|
||||
input = (input / torch.exp2(e8m0_scale)).view(original_shape)
|
||||
input_q = cast_fp4(input)
|
||||
input_q = fuse_uint4_to_uint8(input_q)
|
||||
e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
|
||||
return cls(original_shape, original_dtype, input_q), e8m0_scale
|
||||
|
||||
@classmethod
|
||||
def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
|
||||
"""Dequantze MXFP4 packed tensor to a target dtype."""
|
||||
|
||||
def unfuse_uint8_to_uint4(x):
|
||||
"""Unfuse uint8 values back to uint4 values.
|
||||
This is the inverse operation of fuse_uint4_to_uint8.
|
||||
"""
|
||||
# Extract the lower 4 bits (even indices)
|
||||
left_side = x & 0x0F
|
||||
|
||||
# Extract the upper 4 bits (odd indices)
|
||||
right_side = (x >> 4) & 0x0F
|
||||
|
||||
# Create a new tensor with alternating values
|
||||
shape = list(x.shape)
|
||||
shape[-1] = shape[-1] * 2
|
||||
result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
|
||||
|
||||
# Fill in the values - even indices get low bits, odd indices get high bits
|
||||
result[..., 0::2] = left_side # Even indices from low bits
|
||||
result[..., 1::2] = right_side # Odd indices from high bits
|
||||
|
||||
return result
|
||||
|
||||
e8m0_scale = scale
|
||||
block_size = block_sizes[-1]
|
||||
|
||||
# Unfuse the uint8 values back to uint4
|
||||
x_unfused = unfuse_uint8_to_uint4(quantized_data)
|
||||
# Extract sign and magnitude
|
||||
sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
|
||||
torch.float32
|
||||
) # Extract sign bit and convert to +1/-1
|
||||
magnitude = x_unfused & 0b0111 # Extract magnitude bits
|
||||
magnitude = magnitude.to(torch.long)
|
||||
|
||||
# Create a tensor with the E2M1 values
|
||||
values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
|
||||
|
||||
# Use gather to index the values tensor properly
|
||||
# We need to reshape magnitude to match the dimensions we want to gather along
|
||||
original_shape = magnitude.shape
|
||||
x_float = values[magnitude.reshape(-1)].reshape(original_shape)
|
||||
|
||||
# Apply sign and scale
|
||||
x_float = sign.float() * x_float
|
||||
|
||||
# Reshape to apply block-wise scaling
|
||||
x_float = x_float.reshape(-1, block_size)
|
||||
|
||||
# Apply the E8M0 scale
|
||||
scale_factor = torch.exp2(e8m0_scale.float() - 127)
|
||||
scale_factor = scale_factor.reshape(-1, 1) # Reshape for proper broadcasting
|
||||
|
||||
# Apply scaling and reshape back to original shape
|
||||
x_float = x_float * scale_factor
|
||||
|
||||
# Reshape back to the original shape
|
||||
return x_float.reshape(original_shape).to(dtype)
|
||||
@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, use_triton_kernels: bool = False):
|
||||
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
self.with_bias = with_bias
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
if torch.cuda.is_available() and has_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward as _tk_forward,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
|
||||
)
|
||||
|
||||
self.triton_kernel_moe_forward = _tk_forward
|
||||
self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
if self.with_bias:
|
||||
w13_weight_bias = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
||||
set_weight_attrs(w13_weight_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight_n, w2_weight_k = (
|
||||
hidden_size,
|
||||
@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
if self.with_bias:
|
||||
w2_weight_bias = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
||||
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if _use_aiter:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if activation_alpha is not None:
|
||||
kwargs["activation_alpha"] = activation_alpha
|
||||
if swiglu_limit is not None:
|
||||
kwargs["swiglu_limit"] = swiglu_limit
|
||||
|
||||
return self.forward(
|
||||
x=x,
|
||||
@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_triton_kernels:
|
||||
return self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
if self.with_bias:
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
topk_output=topk_output,
|
||||
activation=activation,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
else:
|
||||
return self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
else:
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
|
||||
@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
is_hybrid = False
|
||||
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
assert isinstance(tree_cache, SWARadixCache) or isinstance(
|
||||
tree_cache, SWAChunkCache
|
||||
assert (
|
||||
tree_cache is None
|
||||
or isinstance(tree_cache, SWARadixCache)
|
||||
or isinstance(tree_cache, SWAChunkCache)
|
||||
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
|
||||
is_hybrid = True
|
||||
|
||||
|
||||
python/sglang/srt/models/gpt_oss.py (new file, 923 lines)
@@ -0,0 +1,923 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
"""Inference-only GptOss model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_tensor_parallel_rank,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
class GptOssConfig(PretrainedConfig):
|
||||
model_type = "gpt_oss"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Aligned with HF's implementation: HF treats the sliding window as inclusive of the last token,
# while SGLang assumes it is exclusive, hence the -1 below.
|
||||
def get_attention_sliding_window_size(config):
|
||||
return config.sliding_window - 1
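A quick illustration (not part of the commit; the config value is hypothetical): with sliding_window = 128, SGLang's exclusive window is 127, i.e. the current token plus 127 previous tokens, the same 128 positions HF counts inclusively.

class _Cfg:                      # hypothetical config stub
    sliding_window = 128

print(get_attention_sliding_window_size(_Cfg()))  # 127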
|
||||
|
||||
|
||||
class GptOssSparseMoeBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GptOssConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layer_id = layer_id
|
||||
self.activation = config.hidden_act
|
||||
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||
self.swiglu_limit = config.swiglu_limit
|
||||
if self.tp_size > config.num_local_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_local_experts}."
|
||||
)
|
||||
|
||||
self.topk = TopK(
|
||||
top_k=config.num_experts_per_tok,
|
||||
renormalize=True,
|
||||
)
|
||||
|
||||
experts_type = get_moe_impl_class()
|
||||
extra_kwargs = {}
|
||||
if experts_type.__name__ == "FusedMoE":
|
||||
extra_kwargs = {
|
||||
"enable_flashinfer_cutlass_moe": global_server_args_dict[
|
||||
"enable_flashinfer_cutlass_moe"
|
||||
],
|
||||
"use_weight_loader_fused": True, # for moe gate_up_proj and down_proj and their bias loading
|
||||
}
|
||||
self.experts = experts_type(
|
||||
num_experts=config.num_local_experts
|
||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||
top_k=config.num_experts_per_tok,
|
||||
layer_id=layer_id,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
activation=self.activation,
|
||||
activation_alpha=self.activation_alpha,
|
||||
swiglu_limit=self.swiglu_limit,
|
||||
with_bias=True,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**(
|
||||
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
self.router = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=True,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
params_dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||
) -> torch.Tensor:
|
||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
return self.forward_normal(hidden_states)
|
||||
else:
|
||||
raise Exception("forward_deepep branch not implemented yet")
|
||||
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
x.data
|
||||
for name, x in self.experts.named_parameters()
|
||||
if name not in ["correction_bias"]
|
||||
]
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
ans = final_hidden_states.view(num_tokens, hidden_dim)
|
||||
return ans
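A minimal sketch (not part of the commit, hypothetical router logits) of the routing step done by TopK(renormalize=True) above: keep the k best experts per token and renormalize their softmax weights so they sum to one.

import torch

router_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])   # 1 token, 4 experts
topk_vals, topk_ids = router_logits.topk(k=2, dim=-1)
topk_weights = torch.softmax(topk_vals, dim=-1)          # renormalized over the top-k
print(topk_ids)      # tensor([[0, 2]])
print(topk_weights)  # weights for experts 0 and 2, summing to 1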
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
attention_bias: bool = False,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
sliding_window_size: int = -1, # if -1, normal attention, else, window attention.
|
||||
layer_type: str = "",
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.sliding_window_size = sliding_window_size
|
||||
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // attn_tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=attention_bias,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.sinks = nn.Parameter(
|
||||
torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
reduce_results=False,
|
||||
params_dtype=params_dtype,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
assert layer_type in {"sliding_attention", "full_attention"}
|
||||
use_sliding_window = layer_type == "sliding_attention"
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
sliding_window_size=(sliding_window_size if use_sliding_window else -1),
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
if hidden_states.shape[0] == 0:
|
||||
return hidden_states, forward_batch, None
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
inner_state = q, k, v, forward_batch
|
||||
return None, forward_batch, inner_state
|
||||
|
||||
def forward_core(self, intermediate_state):
|
||||
hidden_states, forward_batch, inner_state = intermediate_state
|
||||
if inner_state is None:
|
||||
return hidden_states
|
||||
attn_output = self.attn(*inner_state, sk=self.sinks)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
s = self.forward_prepare(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
return self.forward_core(s)


class GptOssDecoderLayer(nn.Module):
    def __init__(
        self,
        config: GptOssConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        sliding_window_size: int | None = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias

        if sliding_window_size is None:
            self.sliding_window_size = get_attention_sliding_window_size(self.config)
        else:
            self.sliding_window_size = sliding_window_size

        self.self_attn = GptOssAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            sliding_window_size=self.sliding_window_size,
            layer_type=config.layer_types[layer_id],
            params_dtype=config.torch_dtype,
        )

        self.layer_id = layer_id

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.local_dp_size = get_local_attention_dp_size()

        # All GptOss layers are sparse (MoE); there is no dense or nextn layer
        # for now.
        self.is_layer_sparse = True
        is_previous_layer_sparse = True

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = GptOssSparseMoeBlock(
                layer_id=self.layer_id,
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            raise NotImplementedError(
                "Dense MLP is not implemented for GptOssDecoderLayer. "
                "Please use GptOssSparseMoeBlock instead."
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual
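
    # The LayerCommunicator calls above handle any scatter/gather needed when
    # attention and the MoE run under different parallel layouts (e.g. DP
    # attention); with a plain TP setup they are largely pass-throughs.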


class GptOssModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not global_server_args_dict["enable_dp_attention"],
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to GptOssDecoderLayer
        decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                if i in self.layers_to_capture:
                    aux_hidden_states.append(hidden_states + residual)
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual
                )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states
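
    # layers_to_capture / aux_hidden_states expose intermediate hidden states
    # (e.g. for EAGLE3 speculative decoding); see
    # set_eagle3_layers_to_capture() on GptOssForCausalLM below.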


class GptOssForCausalLM(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = GptOssModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = False

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids,
                hidden_states,
                self.lm_head,
                forward_batch,
                aux_hidden_states,
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def _get_default_weight_mapping(self):
        """Generate default weight name mapping for GptOss safetensors."""
        weight_mapping = {}

        # Checkpoint name -> model parameter name (top-level weights first,
        # then per-layer attention / MoE weights).
        weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
        weight_mapping["unembedding.weight"] = "lm_head.weight"
        weight_mapping["norm.scale"] = "model.norm.weight"
        for layer_id in range(self.config.num_hidden_layers):
            weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.sinks"] = (
                f"model.layers.{layer_id}.self_attn.sinks"
            )
            weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
                f"model.layers.{layer_id}.input_layernorm.weight"
            )

            weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
                f"model.layers.{layer_id}.mlp.router.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
                f"model.layers.{layer_id}.mlp.router.bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
                f"model.layers.{layer_id}.post_attention_layernorm.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
            )

        return weight_mapping
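
    # Illustrative example: with the mapping above, the checkpoint entry
    # "block.0.mlp.gate.weight" is loaded into the model parameter
    # "model.layers.0.mlp.router.weight".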

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        is_nextn: bool = False,
        weight_name_mapping: Optional[dict] = None,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        if is_nextn:
            logger.warning(
                "Loading weights for nextn is currently not supported in GptOssForCausalLM."
            )
            return
        weights = _canonicalize_weights(self.config, weights)
        weights = sorted(weights, key=lambda x: x[0])  # Sort by name for consistency

        new_weights = []
        for name, p in weights:
            if "qkv.weight" in name:
                q_proj, k_proj, v_proj = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
                )
            elif "qkv.bias" in name:
                q_bias, k_bias, v_bias = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
                )
            else:
                new_weights.append((name, p))
        weights = new_weights

        # Use provided weight name mapping if available, otherwise use default
        if weight_name_mapping is None:
            weight_name_mapping = self._get_default_weight_mapping()
        else:
            # Merge with default mapping
            default_mapping = self._get_default_weight_mapping()
            default_mapping.update(weight_name_mapping)
            weight_name_mapping = default_mapping

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
            ckpt_down_proj_bias_name="down_proj_bias",
        )

        params_dict = dict(self.named_parameters())
        params_checker = {k: False for k, v in params_dict.items()}
        for name, loaded_weight in weights:
            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)

            # Apply weight name mapping if provided
            if weight_name_mapping and name in weight_name_mapping:
                name = weight_name_mapping[name]

            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                params_checker[name] = True
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if "bias" not in name:
                        loaded_weight = loaded_weight.transpose(-2, -1)
                    if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
                        loaded_weight = loaded_weight.zero_()

                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                    )
                    params_checker[name] = True
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue
                    if name in params_dict.keys():
                        param = params_dict[name]
                        if "sinks" in name:
                            start = tp_rank * param.numel()
                            param.data.copy_(
                                loaded_weight[start : start + param.numel()]
                            )
                        else:
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, loaded_weight)
                        params_checker[name] = True
                    else:
                        logger.warning(f"Parameter {name} not found in params_dict")

        not_loaded_params = [k for k, v in params_checker.items() if not v]
        if tp_rank == 0:
            if len(not_loaded_params) > 0:
                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
            else:
                logger.info("All parameters loaded successfully.")

        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
        }

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        if not self.pp_group.is_last_rank:
            return

        if layer_ids is None:
            self.capture_aux_hidden_states = True
            num_layers = self.config.num_hidden_layers
            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
        else:
            self.capture_aux_hidden_states = True
            # Add 1 because in SGLang the i-th layer captures the output of the
            # (i-1)-th layer as its aux hidden state.
            self.model.layers_to_capture = [val + 1 for val in layer_ids]

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_local_experts,
            num_groups=None,
        )

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)


def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
    weights_out_dict = dict(weights_in)

    for layer_id in range(config.num_hidden_layers):
        for name_chunk in ["mlp1_weight", "mlp2_weight"]:
            name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
            w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
            w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
            if w_blocks is not None:
                weights_out_dict[name_prefix] = _WeightCreator(
                    partial(
                        _dequant_mlp_weight,
                        debug_name=name_prefix,
                        w_blocks=w_blocks,
                        w_scales=w_scales,
                    )
                )

    return list(weights_out_dict.items())
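
# Note: the MXFP4 "blocks"/"scales" pairs are not dequantized eagerly here.
# Each pair is wrapped in a lazy _WeightCreator and only materialized (via
# _dequant_mlp_weight below) when load_weights() reaches that entry, so the
# bf16 copies are created one weight at a time during loading.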


def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
    if get_tensor_model_parallel_rank() == 0:
        logger.info(f"Dequantize {debug_name} start")

    original_device = w_blocks.device

    w_blocks = w_blocks.cuda()
    w_scales = w_scales.cuda()

    w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
    w_bf16 = w_bf16.transpose(-2, -1).contiguous()

    if get_tensor_model_parallel_rank() == 0:
        logger.info(
            f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
        )

    return w_bf16.to(original_device)


class _WeightCreator:
    def __init__(self, fn):
        self._fn = fn

    @staticmethod
    def maybe_materialize(obj):
        if isinstance(obj, _WeightCreator):
            output = obj._fn()
            obj._fn = None
            return output

        return obj


EntryClass = GptOssForCausalLM
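

# Reference-only sketch (not used by the code above): one common way a per-head
# "sink" logit like `self.sinks` enters attention is as an extra column in the
# softmax that absorbs probability mass but contributes no value vector. The
# exact kernel behavior is backend-specific; this is only an illustration, and
# the function name and shapes below are illustrative assumptions.
def _sink_attention_reference(q, k, v, sink, scale):
    # q: [heads, q_len, d]; k, v: [heads, kv_len, d]; sink: [heads]
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    sink_col = sink.view(-1, 1, 1).expand(-1, q.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # Drop the sink column; each remaining row now sums to <= 1 by design.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)
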
@@ -457,6 +457,10 @@ class ServerArgs:
            raise ValueError(
                "trtllm_mla backend does not support speculative decoding yet."
            )
        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            self.attention_backend = "triton"
            self.enable_triton_kernel_moe = True

        # Set page size
        if self.page_size is None: