Add initial support for gpt-oss (#8824)
This commit is contained in:
@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
||||||
|
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
self.qo_indptr = torch.zeros(
|
self.qo_indptr = torch.zeros(
|
||||||
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
bs,
|
bs,
|
||||||
self.device,
|
self.device,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
window_num_kv_splits = torch.empty(
|
window_num_kv_splits = torch.empty(
|
||||||
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
# TODO: Support sliding window in spec inference
|
|
||||||
bs = len(forward_batch.req_pool_indices)
|
bs = len(forward_batch.req_pool_indices)
|
||||||
qo_indptr = torch.arange(
|
qo_indptr = torch.arange(
|
||||||
0,
|
0,
|
||||||
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
||||||
|
update_sliding_window_buffer(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.device,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
custom_mask = spec_info.custom_mask
|
custom_mask = spec_info.custom_mask
|
||||||
seq_mask_len = self.num_draft_tokens * (
|
seq_mask_len = self.num_draft_tokens * (
|
||||||
forward_batch.seq_lens + self.num_draft_tokens
|
forward_batch.seq_lens + self.num_draft_tokens
|
||||||
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
bs,
|
bs,
|
||||||
self.device,
|
self.device,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
qo_indptr = self.qo_indptr
|
qo_indptr = self.qo_indptr
|
||||||
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
|
window_kv_indptr, window_kv_indices, _ = (
|
||||||
self.window_kv_indptr,
|
update_sliding_window_buffer_cuda_graph(
|
||||||
window_kv_indices,
|
self.window_kv_indptr,
|
||||||
self.req_to_token,
|
window_kv_indices,
|
||||||
self.sliding_window_size,
|
self.req_to_token,
|
||||||
seq_lens[:bs],
|
self.sliding_window_size,
|
||||||
req_pool_indices,
|
seq_lens[:bs],
|
||||||
bs,
|
req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
|
window_kv_indptr, window_kv_indices, _ = (
|
||||||
|
update_sliding_window_buffer_cuda_graph(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
seq_lens,
|
||||||
|
req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
custom_mask = self.cuda_graph_custom_mask
|
custom_mask = self.cuda_graph_custom_mask
|
||||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||||
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||||
self.window_kv_indptr,
|
self.window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
bs,
|
bs,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
self.get_num_kv_splits(
|
self.get_num_kv_splits(
|
||||||
window_num_kv_splits[:num_token], window_kv_lens[:bs]
|
window_num_kv_splits[:num_token], window_kv_lens[:bs]
|
||||||
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
|
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
seq_lens,
|
||||||
|
req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
custom_mask = self.cuda_graph_custom_mask
|
custom_mask = self.cuda_graph_custom_mask
|
||||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||||
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
# TODO: reuse the buffer across layers
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.forward_metadata.max_extend_len,
|
self.forward_metadata.max_extend_len,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
sliding_window_size,
|
sliding_window_size=sliding_window_size,
|
||||||
|
sk=sk,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||||
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.max_kv_splits,
|
self.max_kv_splits,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
|
sk=sk,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
bs,
|
bs,
|
||||||
device,
|
device,
|
||||||
|
token_to_kv_pool_allocator=None,
|
||||||
):
|
):
|
||||||
window_kv_lens = torch.minimum(
|
window_kv_lens = torch.minimum(
|
||||||
seq_lens,
|
seq_lens,
|
||||||
torch.tensor(sliding_window_size + 1),
|
torch.tensor(sliding_window_size),
|
||||||
)
|
)
|
||||||
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
||||||
window_kv_indptr = window_kv_indptr[: bs + 1]
|
window_kv_indptr = window_kv_indptr[: bs + 1]
|
||||||
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
|
|||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
req_to_token.stride(0),
|
req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
# full to swa index mapping
|
||||||
|
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
|
||||||
|
kv_last_index = window_kv_indptr[-1]
|
||||||
|
window_kv_indices[:kv_last_index] = (
|
||||||
|
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||||
|
window_kv_indices[:kv_last_index]
|
||||||
|
)
|
||||||
|
)
|
||||||
return window_kv_indptr, window_kv_indices, window_kv_lens
|
return window_kv_indptr, window_kv_indices, window_kv_lens
|
||||||
|
|
||||||
|
|
||||||
@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
|
|||||||
seq_lens,
|
seq_lens,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
bs,
|
bs,
|
||||||
|
token_to_kv_pool_allocator=None,
|
||||||
):
|
):
|
||||||
window_kv_lens = torch.minimum(
|
window_kv_lens = torch.minimum(
|
||||||
seq_lens,
|
seq_lens,
|
||||||
torch.tensor(sliding_window_size + 1),
|
torch.tensor(sliding_window_size),
|
||||||
)
|
)
|
||||||
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
||||||
window_kv_indptr = window_kv_indptr[: bs + 1]
|
window_kv_indptr = window_kv_indptr[: bs + 1]
|
||||||
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
|
|||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
req_to_token.stride(0),
|
req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
return window_kv_indptr, window_kv_lens
|
# full to swa index mapping
|
||||||
|
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
|
||||||
|
kv_last_index = window_kv_indptr[-1]
|
||||||
|
window_kv_indices[:kv_last_index] = (
|
||||||
|
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||||
|
window_kv_indices[:kv_last_index]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return window_kv_indptr, window_kv_indices, window_kv_lens
|
||||||
|
|||||||
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
|
|||||||
O,
|
O,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
sk_ptr,
|
||||||
stride_mid_ob,
|
stride_mid_ob,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
stride_mid_os,
|
stride_mid_os,
|
||||||
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
|
|||||||
MIN_BLOCK_KV: tl.constexpr,
|
MIN_BLOCK_KV: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
|
HAS_SK: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
|
|||||||
e_sum = e_sum * old_scale + exp_logic
|
e_sum = e_sum * old_scale + exp_logic
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
|
if HAS_SK:
|
||||||
|
cur_sk = tl.load(sk_ptr + cur_head)
|
||||||
|
e_sum += tl.exp(cur_sk - e_max)
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||||
acc / e_sum,
|
acc / e_sum,
|
||||||
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
batch, head_num = q.shape[0], q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
MAX_KV_SPLITS = max_kv_splits
|
MAX_KV_SPLITS = max_kv_splits
|
||||||
|
HAS_SK = sk is not None
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
o,
|
o,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
sk,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
logits.stride(2),
|
logits.stride(2),
|
||||||
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
|
HAS_SK=HAS_SK,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
num_stages=2,
|
num_stages=2,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
_decode_att_m_fwd(
|
_decode_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
|
sk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
_decode_grouped_att_m_fwd(
|
_decode_grouped_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
|
sk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -687,6 +701,7 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
assert max_kv_splits == attn_logits.shape[2]
|
assert max_kv_splits == attn_logits.shape[2]
|
||||||
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
||||||
@@ -709,6 +724,7 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
|
sk=sk,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# GQA/MQA/MLA
|
# GQA/MQA/MLA
|
||||||
@@ -725,4 +741,5 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
|
sk=sk,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ def _fwd_kernel(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
mask_ptr,
|
mask_ptr,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
|
sk_ptr,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
@@ -78,6 +79,7 @@ def _fwd_kernel(
|
|||||||
IS_CAUSAL: tl.constexpr,
|
IS_CAUSAL: tl.constexpr,
|
||||||
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
||||||
STORE_TRANSPOSE: tl.constexpr,
|
STORE_TRANSPOSE: tl.constexpr,
|
||||||
|
HAS_SK: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -178,13 +180,17 @@ def _fwd_kernel(
|
|||||||
final_mask &= custom_mask
|
final_mask &= custom_mask
|
||||||
if SLIDING_WINDOW_SIZE > 0:
|
if SLIDING_WINDOW_SIZE > 0:
|
||||||
# Add mask where q_id <= kv_id + sliding_window_size
|
# Add mask where q_id <= kv_id + sliding_window_size
|
||||||
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
|
# q_id = prefix_len + cur_m, kv_id = cur_n
|
||||||
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
|
window_mask = (
|
||||||
)
|
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
|
||||||
|
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
|
||||||
final_mask &= window_mask
|
final_mask &= window_mask
|
||||||
qk = tl.where(final_mask, qk, float("-inf"))
|
qk = tl.where(final_mask, qk, float("-inf"))
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
row_max = tl.max(qk, 1)
|
||||||
|
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||||
|
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||||
|
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
p = tl.exp(qk - n_e_max[:, None])
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
deno = deno * re_scale + tl.sum(p, 1)
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
@@ -242,6 +248,7 @@ def _fwd_kernel(
|
|||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||||
if USE_CUSTOM_MASK:
|
if USE_CUSTOM_MASK:
|
||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
mask_ptr
|
mask_ptr
|
||||||
@@ -254,18 +261,30 @@ def _fwd_kernel(
|
|||||||
other=0,
|
other=0,
|
||||||
)
|
)
|
||||||
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
||||||
qk = tl.where(custom_mask, qk, float("-inf"))
|
final_mask &= custom_mask
|
||||||
elif IS_CAUSAL:
|
elif IS_CAUSAL:
|
||||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||||
start_n + offs_n[None, :]
|
start_n + offs_n[None, :]
|
||||||
)
|
)
|
||||||
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||||
qk = tl.where(mask_causual, qk, float("-inf"))
|
final_mask &= mask_causual
|
||||||
else:
|
else:
|
||||||
mask_non_causal = mask_m[:, None] & mask_n[None, :]
|
mask_non_causal = mask_m[:, None] & mask_n[None, :]
|
||||||
qk = tl.where(mask_non_causal, qk, float("-inf"))
|
final_mask &= mask_non_causal
|
||||||
|
|
||||||
|
if SLIDING_WINDOW_SIZE > 0:
|
||||||
|
# Add mask where q_id <= kv_id + sliding_window_size
|
||||||
|
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
|
||||||
|
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
|
||||||
|
)
|
||||||
|
final_mask &= window_mask
|
||||||
|
|
||||||
|
qk = tl.where(final_mask, qk, float("-inf"))
|
||||||
|
|
||||||
|
row_max = tl.max(qk, 1)
|
||||||
|
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||||
|
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
p = tl.exp(qk - n_e_max[:, None])
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
deno = deno * re_scale + tl.sum(p, 1)
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
@@ -283,6 +302,10 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
|
if HAS_SK:
|
||||||
|
cur_sk = tl.load(sk_ptr + cur_head)
|
||||||
|
deno += tl.exp(cur_sk - e_max)
|
||||||
|
|
||||||
offs_o = (
|
offs_o = (
|
||||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_obs
|
* stride_obs
|
||||||
@@ -321,6 +344,7 @@ def extend_attention_fwd(
|
|||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
skip_prefix_custom_mask=True,
|
skip_prefix_custom_mask=True,
|
||||||
sliding_window_size=-1,
|
sliding_window_size=-1,
|
||||||
|
sk=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -386,6 +410,8 @@ def extend_attention_fwd(
|
|||||||
# Skip custom mask for prefix part
|
# Skip custom mask for prefix part
|
||||||
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
||||||
|
|
||||||
|
HAS_SK = sk is not None
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
@@ -405,6 +431,7 @@ def extend_attention_fwd(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
custom_mask,
|
custom_mask,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
|
sk,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
q_extend.stride(0),
|
q_extend.stride(0),
|
||||||
@@ -431,6 +458,7 @@ def extend_attention_fwd(
|
|||||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
IS_CAUSAL=is_causal,
|
IS_CAUSAL=is_causal,
|
||||||
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||||
|
HAS_SK=HAS_SK,
|
||||||
STORE_TRANSPOSE=_is_hip,
|
STORE_TRANSPOSE=_is_hip,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
|||||||
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
|
|||||||
else self.weight_loader
|
else self.weight_loader
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not reduce_results and (bias and not skip_bias_add):
|
|
||||||
raise ValueError(
|
|
||||||
"When not reduce the results, adding bias to the "
|
|
||||||
"results can lead to incorrect results"
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
||||||
|
|||||||
@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
|
use_weight_loader_fused: bool = False,
|
||||||
|
with_bias=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.expert_map_cpu = None
|
self.expert_map_cpu = None
|
||||||
self.expert_map_gpu = None
|
self.expert_map_gpu = None
|
||||||
|
|
||||||
|
# For activation
|
||||||
|
self.activation_alpha = activation_alpha
|
||||||
|
self.swiglu_limit = swiglu_limit
|
||||||
|
|
||||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||||
enable_flashinfer_cutlass_moe = False
|
enable_flashinfer_cutlass_moe = False
|
||||||
@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
self.use_triton_kernels
|
self.use_triton_kernels, with_bias=with_bias
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
intermediate_size=self.intermediate_size_per_partition,
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
weight_loader=self.weight_loader,
|
weight_loader=(
|
||||||
|
self.weight_loader
|
||||||
|
if not use_weight_loader_fused
|
||||||
|
else self.weight_loader_fused
|
||||||
|
),
|
||||||
|
with_bias=with_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_per_tensor_weight_scale(
|
def _load_per_tensor_weight_scale(
|
||||||
@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
is_bias: bool = False,
|
||||||
):
|
):
|
||||||
# Load grouped weight scales for group quantization
|
# Load grouped weight scales for group quantization
|
||||||
# or model weights
|
# or model weights
|
||||||
@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
|
|||||||
loaded_weight=loaded_weight,
|
loaded_weight=loaded_weight,
|
||||||
expert_data=expert_data,
|
expert_data=expert_data,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
|
is_bias=is_bias,
|
||||||
)
|
)
|
||||||
elif shard_id in ("w1", "w3"):
|
elif shard_id in ("w1", "w3", "w13"):
|
||||||
self._load_w13(
|
self._load_w13(
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
shard_dim=shard_dim,
|
shard_dim=shard_dim,
|
||||||
loaded_weight=loaded_weight,
|
loaded_weight=loaded_weight,
|
||||||
expert_data=expert_data,
|
expert_data=expert_data,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
|
is_bias=is_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_per_channel_weight_scale(
|
def _load_per_channel_weight_scale(
|
||||||
@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
is_bias: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
# Index the loaded weight for tp sharding.
|
# Index the loaded weight for tp sharding.
|
||||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||||
shard_size = expert_data.shape[shard_dim] // 2
|
assert shard_id in {"w1", "w3", "w13"}
|
||||||
|
|
||||||
|
if is_bias:
|
||||||
|
# if this weight is a bias, the last dimension must be the sharded dimension
|
||||||
|
shard_dim = -1
|
||||||
|
|
||||||
|
if shard_id in {"w1", "w3"}:
|
||||||
|
# non-fused version
|
||||||
|
shard_size = expert_data.shape[shard_dim] // 2
|
||||||
|
elif shard_id in {"w13"}:
|
||||||
|
# fused version
|
||||||
|
shard_size = expert_data.shape[shard_dim]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
# Narrow parameter and load.
|
# Narrow parameter and load.
|
||||||
# w1, gate_proj: Load into first logical weight of w13.
|
# w1, gate_proj: Load into first logical weight of w13.
|
||||||
# w3, up_proj: Load into second logical weight of w13.
|
# w3, up_proj: Load into second logical weight of w13.
|
||||||
# trtllm cutlass kernel assumes differently
|
# trtllm cutlass kernel assumes differently
|
||||||
assert shard_id in ("w1", "w3")
|
|
||||||
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
|
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
|
||||||
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
|
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
|
||||||
start = shard_size
|
start = shard_size
|
||||||
@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not self.use_presharded_weights:
|
if not self.use_presharded_weights:
|
||||||
if self.use_triton_kernels:
|
if not is_bias and self.use_triton_kernels:
|
||||||
|
# do not transpose for bias
|
||||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
shard_dim, shard_size * tp_rank, shard_size
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
is_bias: bool = False,
|
||||||
):
|
):
|
||||||
"""Load w2 weights for down projection.
|
"""Load w2 weights for down projection.
|
||||||
|
|
||||||
@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# Index the loaded weight for tp sharding.
|
# Index the loaded weight for tp sharding.
|
||||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||||
# Narrow parameter and load.
|
# Narrow parameter and load.
|
||||||
shard_size = expert_data.shape[shard_dim]
|
if is_bias:
|
||||||
|
# this expert_data is a bias, not weight,
|
||||||
|
# for w2_bias in TP, it does not need to be sharded
|
||||||
|
shard_size = expert_data.shape[-1]
|
||||||
|
else:
|
||||||
|
# this parameter is a weight matrix
|
||||||
|
# for w2 in TP, it shards the input_features, i.e., shard_dim=2
|
||||||
|
shard_size = expert_data.shape[shard_dim]
|
||||||
|
|
||||||
if _is_cpu:
|
if _is_cpu:
|
||||||
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||||
@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
not self.use_presharded_weights,
|
not self.use_presharded_weights,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not self.use_presharded_weights:
|
if not is_bias and not self.use_presharded_weights:
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
loaded_weight = loaded_weight.transpose(-2, -1)
|
loaded_weight = loaded_weight.transpose(-2, -1)
|
||||||
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
|
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
|
||||||
@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def weight_loader_fused(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: str,
|
||||||
|
) -> None:
|
||||||
|
tp_rank = self.moe_tp_rank
|
||||||
|
|
||||||
|
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||||
|
# TODO: check self.quant_method.quant_config.quant_format
|
||||||
|
# against known CompressionFormat enum values that have this quality
|
||||||
|
loaded_weight = (
|
||||||
|
loaded_weight.t().contiguous()
|
||||||
|
if (
|
||||||
|
self.quant_method.__class__.__name__
|
||||||
|
== "CompressedTensorsWNA16MoEMethod"
|
||||||
|
)
|
||||||
|
else loaded_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
if shard_id not in ("w13", "w2"):
|
||||||
|
raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
|
||||||
|
|
||||||
|
# Fetch the dim to shard the parameter/loaded weight
|
||||||
|
# based on the shard id. This will be whatever
|
||||||
|
# dimension intermediate_size is used.
|
||||||
|
SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
|
||||||
|
SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
|
||||||
|
|
||||||
|
expert_data = param.data
|
||||||
|
is_bias = expert_data.dim() == 2
|
||||||
|
|
||||||
|
# is_transposed: if the dim to shard the weight
|
||||||
|
# should be flipped. Required by GPTQ, compressed-tensors
|
||||||
|
# should be whatever dimension intermediate_size is
|
||||||
|
is_transposed = getattr(param, "is_transposed", False)
|
||||||
|
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
is_transposed = True
|
||||||
|
shard_dim = (
|
||||||
|
SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||||
|
if not is_transposed
|
||||||
|
else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case model weights
|
||||||
|
if "weight" in weight_name:
|
||||||
|
self._load_model_weight_or_group_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
is_bias=is_bias,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
with use_symmetric_memory(get_tp_group()) as sm:
|
with use_symmetric_memory(get_tp_group()) as sm:
|
||||||
|
kwargs = {}
|
||||||
|
if self.activation_alpha is not None:
|
||||||
|
kwargs["activation_alpha"] = self.activation_alpha
|
||||||
|
if self.swiglu_limit is not None:
|
||||||
|
kwargs["swiglu_limit"] = self.swiglu_limit
|
||||||
|
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
== "ModelOptNvFp4FusedMoEMethod"
|
== "ModelOptNvFp4FusedMoEMethod"
|
||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
sm.tag(final_hidden_states)
|
sm.tag(final_hidden_states)
|
||||||
|
|
||||||
@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_expert_params_mapping_fused(
|
||||||
|
cls,
|
||||||
|
ckpt_gate_up_proj_name: str,
|
||||||
|
ckpt_down_proj_name: str,
|
||||||
|
ckpt_gate_up_proj_bias_name: str,
|
||||||
|
ckpt_down_proj_bias_name: str,
|
||||||
|
):
|
||||||
|
return [
|
||||||
|
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
|
||||||
|
(
|
||||||
|
"experts.w13_weight_bias",
|
||||||
|
f"experts.{ckpt_gate_up_proj_bias_name}",
|
||||||
|
"w13",
|
||||||
|
),
|
||||||
|
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
|
||||||
|
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_expert_input_scale_params_mapping(
|
def make_expert_input_scale_params_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
from triton_kernels.matmul_ogs import matmul_ogs
|
from triton_kernels.matmul_ogs import (
|
||||||
|
FlexCtx,
|
||||||
|
FnSpecs,
|
||||||
|
FusedActivation,
|
||||||
|
PrecisionConfig,
|
||||||
|
matmul_ogs,
|
||||||
|
)
|
||||||
|
from triton_kernels.numerics import InFlexData
|
||||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||||
|
from triton_kernels.swiglu import swiglu_fn
|
||||||
from sglang.srt.utils import direct_register_custom_op
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
|
|
||||||
|
def quantize(w, dtype, dev, **opt):
|
||||||
|
if dtype == "bf16":
|
||||||
|
return w.to(torch.bfloat16), InFlexData()
|
||||||
|
elif dtype == "fp8":
|
||||||
|
wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
|
||||||
|
return (
|
||||||
|
wq,
|
||||||
|
InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
|
||||||
|
MicroscalingCtx(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert dtype == "mx4", f"{dtype=}"
|
||||||
|
swizzle_mx_scale = opt["swizzle_mx_scale"]
|
||||||
|
swizzle_axis = 2 if swizzle_mx_scale else None
|
||||||
|
w = w.to(torch.bfloat16)
|
||||||
|
w, mx_scales, weight_scale_shape = downcast_to_mxfp(
|
||||||
|
w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
w,
|
||||||
|
InFlexData(),
|
||||||
|
MicroscalingCtx(
|
||||||
|
weight_scale=mx_scales,
|
||||||
|
swizzle_mx=swizzle_mx_scale,
|
||||||
|
actual_weight_scale_shape=weight_scale_shape,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def triton_kernel_moe_forward(
|
def triton_kernel_moe_forward(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return intermediate_cache3
|
return intermediate_cache3
|
||||||
|
|
||||||
|
|
||||||
|
def triton_kernel_moe_with_bias_forward(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
b1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
b2: torch.Tensor,
|
||||||
|
topk_output: TopKOutput,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert topk_output.format.is_triton_kernel()
|
||||||
|
routing_data, gather_idx, scatter_idx = topk_output
|
||||||
|
|
||||||
|
return triton_kernel_fused_experts_with_bias(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
b1,
|
||||||
|
w2,
|
||||||
|
b2,
|
||||||
|
routing_data,
|
||||||
|
gather_idx,
|
||||||
|
scatter_idx,
|
||||||
|
inplace=inplace,
|
||||||
|
activation=activation,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
per_channel_quant=per_channel_quant,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=block_shape,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_kernel_fused_experts_with_bias(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
b1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
b2: torch.Tensor,
|
||||||
|
routing_data: RoutingData,
|
||||||
|
gather_indx: GatherIndx,
|
||||||
|
scatter_indx: ScatterIndx,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
|
||||||
|
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||||
|
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||||
|
assert expert_map == None, "expert_map is not supported"
|
||||||
|
assert w1_scale == None, "w1_scale is not supported"
|
||||||
|
assert w2_scale == None, "w2_scale is not supported"
|
||||||
|
assert a1_scale == None, "a1_scale is not supported"
|
||||||
|
assert a2_scale == None, "a2_scale is not supported"
|
||||||
|
assert block_shape == None, "block_shape is not supported"
|
||||||
|
|
||||||
|
# type check
|
||||||
|
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||||
|
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
|
||||||
|
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
|
||||||
|
|
||||||
|
# Shape check
|
||||||
|
assert hidden_states.ndim == 2, "hidden_states must be 2D"
|
||||||
|
assert (
|
||||||
|
hidden_states.shape[-1] == w1.shape[-2]
|
||||||
|
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
|
||||||
|
assert (
|
||||||
|
w2.shape[-1] == w1.shape[1]
|
||||||
|
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||||
|
|
||||||
|
# feature check
|
||||||
|
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
||||||
|
|
||||||
|
E, _, _ = w1.shape
|
||||||
|
|
||||||
|
if global_num_experts == -1:
|
||||||
|
global_num_experts = E
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
optg = dict()
|
||||||
|
w1, w1_flex = quantize(w1, "bf16", device, **optg)
|
||||||
|
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
|
||||||
|
|
||||||
|
w2, w2_flex = quantize(w2, "bf16", device, **optg)
|
||||||
|
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
|
||||||
|
|
||||||
|
act = FusedActivation(
|
||||||
|
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
|
||||||
|
(activation_alpha, swiglu_limit),
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
|
||||||
|
intermediate_cache = matmul_ogs(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
b1,
|
||||||
|
routing_data,
|
||||||
|
gather_indx=gather_indx,
|
||||||
|
precision_config=w1_pcg,
|
||||||
|
gammas=None,
|
||||||
|
fused_activation=act,
|
||||||
|
)
|
||||||
|
|
||||||
|
return matmul_ogs(
|
||||||
|
intermediate_cache,
|
||||||
|
w2,
|
||||||
|
b2,
|
||||||
|
routing_data,
|
||||||
|
scatter_indx=scatter_indx,
|
||||||
|
precision_config=w2_pcg,
|
||||||
|
gammas=routing_data.gate_scal,
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||||
|
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
align,
|
align,
|
||||||
|
ceil_div,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_cuda_version,
|
get_cuda_version,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
|
|||||||
return output.to(dtype=input_2d.dtype).view(*output_shape)
|
return output.to(dtype=input_2d.dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def dequant_mxfp4(
|
||||||
|
w_block: torch.Tensor,
|
||||||
|
w_scale: torch.Tensor,
|
||||||
|
out_dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
:param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
|
||||||
|
:param w_scale: (batch, n, k), uint8
|
||||||
|
:return: (batch, n, k * 32), float32
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert w_block.dtype == torch.uint8
|
||||||
|
assert w_scale.dtype == torch.uint8
|
||||||
|
|
||||||
|
batch, n, k, pack_dim = w_block.shape
|
||||||
|
batch_, n_, k_ = w_scale.shape
|
||||||
|
assert pack_dim == 16
|
||||||
|
assert batch == batch_
|
||||||
|
assert n == n_
|
||||||
|
assert k == k_
|
||||||
|
|
||||||
|
out_raw = MXFP4QuantizeUtil.dequantize(
|
||||||
|
quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
|
||||||
|
)
|
||||||
|
return out_raw.reshape(batch, n, k * 32)
|
||||||
|
|
||||||
|
|
||||||
def input_to_float8(
|
def input_to_float8(
|
||||||
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
|
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|||||||
133
python/sglang/srt/layers/quantization/mxfp4_tensor.py
Normal file
133
python/sglang/srt/layers/quantization/mxfp4_tensor.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
|
||||||
|
class MXFP4QuantizeUtil:
|
||||||
|
E2M1_max = 6.0
|
||||||
|
|
||||||
|
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||||
|
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
|
||||||
|
"""Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): The input tensor to be quantized.
|
||||||
|
block_sizes (dict | None): The block sizes for quantization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def cast_fp4(x):
|
||||||
|
sign = torch.sign(x)
|
||||||
|
sign_bit = (2 - sign) // 2
|
||||||
|
ord_ = torch.sum(
|
||||||
|
(x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
|
||||||
|
)
|
||||||
|
fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
|
||||||
|
return fp4_val
|
||||||
|
|
||||||
|
def fuse_uint4_to_uint8(x):
|
||||||
|
# If the last dimension is odd, pad with zeros
|
||||||
|
# If this behavior is not desired, please modify the code accordingly
|
||||||
|
left_side = x[..., 0::2] # Even indices (0, 2, 4...)
|
||||||
|
right_side = x[..., 1::2] # Odd indices (1, 3, 5...)
|
||||||
|
new_data = (
|
||||||
|
right_side.clone() << 4
|
||||||
|
) # Put odd indices (higher addresses) in high bits
|
||||||
|
new_data[
|
||||||
|
..., : left_side.shape[-1]
|
||||||
|
] += left_side # Put even indices in low bits
|
||||||
|
return new_data
|
||||||
|
|
||||||
|
if block_size is None:
|
||||||
|
block_size = 32
|
||||||
|
|
||||||
|
original_shape = input.shape
|
||||||
|
original_dtype = input.dtype
|
||||||
|
input = input.view(-1, block_size)
|
||||||
|
# get scales
|
||||||
|
input_amax = input.abs().max(dim=-1, keepdim=True).values
|
||||||
|
descale = input_amax / cls.E2M1_max
|
||||||
|
min_value = torch.tensor(-127.0, device=descale.device)
|
||||||
|
e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
|
||||||
|
|
||||||
|
input = (input / torch.exp2(e8m0_scale)).view(original_shape)
|
||||||
|
input_q = cast_fp4(input)
|
||||||
|
input_q = fuse_uint4_to_uint8(input_q)
|
||||||
|
e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
|
||||||
|
return cls(original_shape, original_dtype, input_q), e8m0_scale
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
|
||||||
|
"""Dequantze MXFP4 packed tensor to a target dtype."""
|
||||||
|
|
||||||
|
def unfuse_uint8_to_uint4(x):
|
||||||
|
"""Unfuse uint8 values back to uint4 values.
|
||||||
|
This is the inverse operation of fuse_uint4_to_uint8.
|
||||||
|
"""
|
||||||
|
# Extract the lower 4 bits (even indices)
|
||||||
|
left_side = x & 0x0F
|
||||||
|
|
||||||
|
# Extract the upper 4 bits (odd indices)
|
||||||
|
right_side = (x >> 4) & 0x0F
|
||||||
|
|
||||||
|
# Create a new tensor with alternating values
|
||||||
|
shape = list(x.shape)
|
||||||
|
shape[-1] = shape[-1] * 2
|
||||||
|
result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
|
||||||
|
|
||||||
|
# Fill in the values - even indices get low bits, odd indices get high bits
|
||||||
|
result[..., 0::2] = left_side # Even indices from low bits
|
||||||
|
result[..., 1::2] = right_side # Odd indices from high bits
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
e8m0_scale = scale
|
||||||
|
block_size = block_sizes[-1]
|
||||||
|
|
||||||
|
# Unfuse the uint8 values back to uint4
|
||||||
|
x_unfused = unfuse_uint8_to_uint4(quantized_data)
|
||||||
|
# Extract sign and magnitude
|
||||||
|
sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
|
||||||
|
torch.float32
|
||||||
|
) # Extract sign bit and convert to +1/-1
|
||||||
|
magnitude = x_unfused & 0b0111 # Extract magnitude bits
|
||||||
|
magnitude = magnitude.to(torch.long)
|
||||||
|
|
||||||
|
# Create a tensor with the E2M1 values
|
||||||
|
values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
|
||||||
|
|
||||||
|
# Use gather to index the values tensor properly
|
||||||
|
# We need to reshape magnitude to match the dimensions we want to gather along
|
||||||
|
original_shape = magnitude.shape
|
||||||
|
x_float = values[magnitude.reshape(-1)].reshape(original_shape)
|
||||||
|
|
||||||
|
# Apply sign and scale
|
||||||
|
x_float = sign.float() * x_float
|
||||||
|
|
||||||
|
# Reshape to apply block-wise scaling
|
||||||
|
x_float = x_float.reshape(-1, block_size)
|
||||||
|
|
||||||
|
# Apply the E8M0 scale
|
||||||
|
scale_factor = torch.exp2(e8m0_scale.float() - 127)
|
||||||
|
scale_factor = scale_factor.reshape(-1, 1) # Reshape for proper broadcasting
|
||||||
|
|
||||||
|
# Apply scaling and reshape back to original shape
|
||||||
|
x_float = x_float * scale_factor
|
||||||
|
|
||||||
|
# Reshape back to the original shape
|
||||||
|
return x_float.reshape(original_shape).to(dtype)
|
||||||
@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
"""MoE method without quantization."""
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
def __init__(self, use_triton_kernels: bool = False):
|
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_triton_kernels = use_triton_kernels
|
self.use_triton_kernels = use_triton_kernels
|
||||||
|
self.with_bias = with_bias
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = None
|
self.triton_kernel_moe_forward = None
|
||||||
|
self.triton_kernel_moe_with_bias_forward = None
|
||||||
if torch.cuda.is_available() and has_triton_kernels:
|
if torch.cuda.is_available() and has_triton_kernels:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
triton_kernel_moe_forward as _tk_forward,
|
triton_kernel_moe_forward as _tk_forward,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
|
triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
|
||||||
|
)
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = _tk_forward
|
self.triton_kernel_moe_forward = _tk_forward
|
||||||
|
self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer.register_parameter("w13_weight", w13_weight)
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
if self.with_bias:
|
||||||
|
w13_weight_bias = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
||||||
|
set_weight_attrs(w13_weight_bias, extra_weight_attrs)
|
||||||
|
|
||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
w2_weight_n, w2_weight_k = (
|
w2_weight_n, w2_weight_k = (
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer.register_parameter("w2_weight", w2_weight)
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
if self.with_bias:
|
||||||
|
w2_weight_bias = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
||||||
|
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
kwargs = {}
|
||||||
|
if activation_alpha is not None:
|
||||||
|
kwargs["activation_alpha"] = activation_alpha
|
||||||
|
if swiglu_limit is not None:
|
||||||
|
kwargs["swiglu_limit"] = swiglu_limit
|
||||||
|
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
return self.triton_kernel_moe_forward(
|
if self.with_bias:
|
||||||
hidden_states=x,
|
return self.triton_kernel_moe_with_bias_forward(
|
||||||
w1=layer.w13_weight,
|
hidden_states=x,
|
||||||
w2=layer.w2_weight,
|
w1=layer.w13_weight,
|
||||||
topk_output=topk_output,
|
w2=layer.w2_weight,
|
||||||
)
|
b1=layer.w13_weight_bias,
|
||||||
|
b2=layer.w2_weight_bias,
|
||||||
|
topk_output=topk_output,
|
||||||
|
activation=activation,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.triton_kernel_moe_forward(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_output=topk_output,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
assert not no_combine, "unsupported"
|
assert not no_combine, "unsupported"
|
||||||
|
|||||||
@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
is_hybrid = False
|
is_hybrid = False
|
||||||
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||||
assert isinstance(tree_cache, SWARadixCache) or isinstance(
|
assert (
|
||||||
tree_cache, SWAChunkCache
|
tree_cache is None
|
||||||
|
or isinstance(tree_cache, SWARadixCache)
|
||||||
|
or isinstance(tree_cache, SWAChunkCache)
|
||||||
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
|
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
|
||||||
is_hybrid = True
|
is_hybrid = True
|
||||||
|
|
||||||
|
|||||||
923
python/sglang/srt/models/gpt_oss.py
Normal file
923
python/sglang/srt/models/gpt_oss.py
Normal file
@@ -0,0 +1,923 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
"""Inference-only GptOss model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
get_moe_tensor_parallel_rank,
|
||||||
|
get_pp_group,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
|
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
get_local_attention_dp_size,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
|
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.utils import add_prefix, make_layers
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssConfig(PretrainedConfig):
|
||||||
|
model_type = "gpt_oss"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Aligned with HF's implementation, using sliding window inclusive with the last token
|
||||||
|
# SGLang assumes exclusive
|
||||||
|
def get_attention_sliding_window_size(config):
|
||||||
|
return config.sliding_window - 1
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssSparseMoeBlock(nn.Module):
    """Sparse MoE feed-forward block for gpt-oss.

    Routes each token through ``num_experts_per_tok`` experts selected by a
    replicated linear router, using whatever expert implementation
    ``get_moe_impl_class()`` resolves to for the current server configuration.
    """

    def __init__(
        self,
        layer_id: int,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        # Activation settings are forwarded verbatim to the expert kernels.
        self.activation = config.hidden_act
        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
        self.swiglu_limit = config.swiglu_limit
        if self.tp_size > config.num_local_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_local_experts}."
            )

        self.topk = TopK(
            top_k=config.num_experts_per_tok,
            renormalize=True,
        )

        # The concrete expert class varies with the server configuration;
        # FusedMoE needs extra flags the other implementations do not accept.
        experts_type = get_moe_impl_class()
        extra_kwargs = {}
        if experts_type.__name__ == "FusedMoE":
            extra_kwargs = {
                "enable_flashinfer_cutlass_moe": global_server_args_dict[
                    "enable_flashinfer_cutlass_moe"
                ],
                "use_weight_loader_fused": True,  # for moe gate_up_proj and down_proj and their bias loading
            }
        self.experts = experts_type(
            num_experts=config.num_local_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            top_k=config.num_experts_per_tok,
            layer_id=layer_id,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            activation=self.activation,
            activation_alpha=self.activation_alpha,
            swiglu_limit=self.swiglu_limit,
            with_bias=True,
            prefix=add_prefix("experts", prefix),
            # DeepEP mode is only forwarded when the a2a backend is deepep.
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                else {}
            ),
            **extra_kwargs,
        )

        # Router is replicated (not TP-sharded) so every rank sees full logits.
        self.router = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=True,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
            params_dtype=config.torch_dtype,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Dispatch to the normal (non-DeepEP) MoE path; DeepEP is unimplemented."""
        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
            return self.forward_normal(hidden_states)
        else:
            raise Exception("forward_deepep branch not implemented yet")

    def get_moe_weights(self):
        """Return raw expert weight tensors (excluding correction_bias)."""
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and all-reduce across TP ranks."""
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.router(hidden_states)

        # Expert impls differ in whether they take precomputed top-k output
        # or raw router logits.
        kwargs = {"hidden_states": hidden_states}
        if self.topk is not None:
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
        else:
            kwargs["router_logits"] = router_logits
        final_hidden_states = self.experts(**kwargs)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        ans = final_hidden_states.view(num_tokens, hidden_dim)
        return ans
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssAttention(nn.Module):
    """Self-attention for gpt-oss with optional sliding-window and "sink" logits.

    Heads are partitioned across the attention tensor-parallel group; KV heads
    are partitioned when there are at least as many as TP ranks, otherwise
    replicated. ``layer_type`` selects full vs. sliding-window attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        sliding_window_size: int = -1,  # if -1, normal attention, else, window attention.
        layer_type: str = "",
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.sliding_window_size = sliding_window_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.tp_rank = get_tensor_model_parallel_rank()

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("qkv_proj", prefix),
        )

        # Per-head attention "sink" logits; loaded from checkpoint, not trained.
        self.sinks = nn.Parameter(
            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            # Reduction is deferred to the layer communicator, not done here.
            reduce_results=False,
            params_dtype=params_dtype,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        assert layer_type in {"sliding_attention", "full_attention"}
        use_sliding_window = layer_type == "sliding_attention"
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
            # -1 disables windowing for full-attention layers.
            sliding_window_size=(sliding_window_size if use_sliding_window else -1),
        )
        self.layer_id = layer_id

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Compute rotary-embedded Q/K/V; returns (hidden, batch, inner_state).

        For an empty batch, inner_state is None and hidden_states is passed
        through unchanged so forward_core can short-circuit.
        """
        if hidden_states.shape[0] == 0:
            return hidden_states, forward_batch, None
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        inner_state = q, k, v, forward_batch
        return None, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        """Run attention (with sink logits) and the output projection."""
        hidden_states, forward_batch, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states
        attn_output = self.attn(*inner_state, sk=self.sinks)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        return self.forward_core(s)
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssDecoderLayer(nn.Module):
    """One gpt-oss transformer layer: attention + sparse MoE MLP.

    Communication (scatter/gather across attention-TP/DP groups) is handled by
    a LayerCommunicator wrapping the two RMSNorms.
    """

    def __init__(
        self,
        config: GptOssConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        sliding_window_size: int | None = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias

        # Default to the config-derived (exclusive) window when not overridden.
        if sliding_window_size is None:
            self.sliding_window_size = get_attention_sliding_window_size(self.config)
        else:
            self.sliding_window_size = sliding_window_size

        self.self_attn = GptOssAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            sliding_window_size=self.sliding_window_size,
            # Per-layer full vs. sliding attention comes from the config.
            layer_type=config.layer_types[layer_id],
            params_dtype=config.torch_dtype,
        )

        self.layer_id = layer_id

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.local_dp_size = get_local_attention_dp_size()

        # GptOss all layers are sparse and have no nextn now
        self.is_layer_sparse = True
        is_previous_layer_sparse = True

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = GptOssSparseMoeBlock(
                layer_id=self.layer_id,
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            raise NotImplementedError(
                "Dense MLP is not implemented for GptOssDecoderLayer. "
                "Please use GptOssSparseMoeBlock instead."
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run norm -> attention -> norm -> MLP with residual bookkeeping
        delegated to the layer communicator."""
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        # Skip attention entirely for empty (e.g. padding-only) batches.
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssModel(nn.Module):
    """Pipeline-parallel-aware gpt-oss transformer stack (embeddings, layers, norm).

    Non-first PP ranks receive hidden states via PPProxyTensors instead of
    embedding input ids; non-last ranks return proxy tensors instead of the
    final normed hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        # Only the first PP rank owns the embedding table.
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not global_server_args_dict["enable_dp_attention"],
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to GptOssDecoderLayer
        decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        # Only the last PP rank applies the final norm.
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        # Layer indices whose (pre-layer) hidden states are captured as aux
        # outputs, e.g. for EAGLE3 speculative decoding.
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run the local slice of layers; see class docstring for PP behavior.

        Returns final hidden states on the last PP rank (plus aux hidden
        states when any were captured), PPProxyTensors otherwise.
        """
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                # Capture the layer *input* (hidden + residual) when requested.
                if i in self.layers_to_capture:
                    aux_hidden_states.append(hidden_states + residual)
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual
                )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
            if len(aux_hidden_states) == 0:
                return hidden_states

            return hidden_states, aux_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssForCausalLM(nn.Module):
    """Causal-LM wrapper for gpt-oss: model stack + LM head + logits processing.

    Also implements checkpoint weight loading, including splitting fused qkv
    tensors, dequantizing MXFP4 MoE weights (via _canonicalize_weights), and
    mapping the checkpoint's ``block.N.*`` names onto module parameters.
    """

    # Checkpoints load directly from safetensors; no PyTorch-format fallback.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = GptOssModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        # Toggled on by set_eagle3_layers_to_capture for speculative decoding.
        self.capture_aux_hidden_states = False

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the model; last PP rank returns logits, others proxy tensors."""
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            # The model returns (hidden, aux) when capture is enabled.
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids,
                hidden_states,
                self.lm_head,
                forward_batch,
                aux_hidden_states,
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def _get_default_weight_mapping(self):
        """Generate default weight name mapping for GptOss safetensors."""
        weight_mapping = {}

        # Map router weights to gate
        weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
        weight_mapping["unembedding.weight"] = "lm_head.weight"
        weight_mapping["norm.scale"] = "model.norm.weight"
        for layer_id in range(self.config.num_hidden_layers):
            weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.bias"
            )

            weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.sinks"] = (
                f"model.layers.{layer_id}.self_attn.sinks"
            )
            weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
                f"model.layers.{layer_id}.input_layernorm.weight"
            )

            weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
                f"model.layers.{layer_id}.mlp.router.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
                f"model.layers.{layer_id}.mlp.router.bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
                f"model.layers.{layer_id}.post_attention_layernorm.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
            )

        return weight_mapping

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        is_nextn: bool = False,
        weight_name_mapping: dict = None,
    ):
        """Load checkpoint weights into the model.

        Pipeline: canonicalize MXFP4 weights -> split fused qkv -> rename via
        mapping -> route each tensor through stacked-param, expert, or generic
        loaders (the nested for/else chains implement this fall-through).
        Raises on rank 0 if any parameter was left unloaded.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if is_nextn:
            logging.warning(
                "Loading weights for nextn is currently not supported in GptOssForCausalLM. "
            )
            return
        weights = _canonicalize_weights(self.config, weights)
        weights = sorted(weights, key=lambda x: x[0])  # Sort by name for consistency

        # Split fused qkv weights/biases into separate q/k/v entries so the
        # stacked-params loader below can shard them individually.
        new_weights = []
        for name, p in weights:
            if "qkv.weight" in name:
                q_proj, k_proj, v_proj = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
                )
            elif "qkv.bias" in name:
                q_bias, k_bias, v_bias = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
                )
            else:
                new_weights.append((name, p))
        weights = new_weights

        # Use provided weight name mapping if available, otherwise use default
        if weight_name_mapping is None:
            weight_name_mapping = self._get_default_weight_mapping()
        else:
            # Merge with default mapping
            default_mapping = self._get_default_weight_mapping()
            default_mapping.update(weight_name_mapping)
            weight_name_mapping = default_mapping

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
            ckpt_down_proj_bias_name="down_proj_bias",
        )

        params_dict = dict(self.named_parameters())
        # Tracks which parameters actually received a weight.
        params_checker = {k: False for k, v in params_dict.items()}
        for name, loaded_weight in weights:
            # Lazily dequantize MXFP4 weights wrapped by _canonicalize_weights.
            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)

            # Apply weight name mapping if provided
            if weight_name_mapping and name in weight_name_mapping:
                name = weight_name_mapping[name]

            # Skip layers outside this PP rank's slice.
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue
            # 1) Try the stacked (fused qkv) parameter loaders.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                params_checker[name] = True
                break
            else:
                # 2) Not a stacked param: try the MoE expert loaders.
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if "bias" not in name:
                        # Checkpoint stores expert weights transposed.
                        loaded_weight = loaded_weight.transpose(-2, -1)
                    if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
                        # Zero duplicate down-proj bias on non-zero MoE-TP
                        # ranks so the all-reduce adds it exactly once.
                        loaded_weight = loaded_weight.zero_()

                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                    )
                    params_checker[name] = True
                    break
                else:
                    # 3) Fallback: plain parameter load.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue
                    if name in params_dict.keys():
                        param = params_dict[name]
                        if "sinks" in name:
                            # Sinks are per-head; slice this TP rank's share.
                            start = tp_rank * param.numel()
                            param.data.copy_(
                                loaded_weight[start : start + param.numel()]
                            )
                        else:
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, loaded_weight)
                        params_checker[name] = True
                    else:
                        logger.warning(f"Parameter {name} not found in params_dict")

        not_loaded_params = [k for k, v in params_checker.items() if not v]
        if tp_rank == 0:
            if len(not_loaded_params) > 0:
                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
            else:
                logging.info("All parameters loaded successfully.")

        # Expose routed expert weights per layer (used e.g. for EPLB rebalancing).
        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
        }

    def get_embed_and_head(self):
        """Return (embedding weight, LM head weight) tensors."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace embedding and LM head weights in place (frees the old ones)."""
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        """Enable aux-hidden-state capture (EAGLE3); last PP rank only."""
        if not self.pp_group.is_last_rank:
            return

        if layer_ids is None:
            self.capture_aux_hidden_states = True
            num_layers = self.config.num_hidden_layers
            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
        else:
            self.capture_aux_hidden_states = True
            # we plus 1 here because in sglang, for the ith layer, it takes the output
            # of the (i-1)th layer as aux hidden state
            self.model.layers_to_capture = [val + 1 for val in layer_ids]

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe expert layout for the expert-location manager."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_local_experts,
            num_groups=None,
        )

    def get_attention_sliding_window_size(self):
        """Exclusive sliding-window size derived from the config."""
        return get_attention_sliding_window_size(self.config)
|
||||||
|
|
||||||
|
|
||||||
|
def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
    """Fold per-layer MXFP4 ``.blocks``/``.scales`` tensor pairs into lazy entries.

    For each layer's mlp1/mlp2 weight, the raw block and scale tensors are
    removed and replaced by a single `_WeightCreator` that dequantizes on
    first use. All other (name, tensor) pairs pass through untouched.
    """
    canonical = dict(weights_in)

    for idx in range(config.num_hidden_layers):
        for chunk in ("mlp1_weight", "mlp2_weight"):
            prefix = f"block.{idx}.mlp.{chunk}"
            blocks = canonical.pop(f"{prefix}.blocks", None)
            scales = canonical.pop(f"{prefix}.scales", None)
            if blocks is None:
                continue
            canonical[prefix] = _WeightCreator(
                partial(
                    _dequant_mlp_weight,
                    debug_name=prefix,
                    w_blocks=blocks,
                    w_scales=scales,
                )
            )

    return list(canonical.items())
|
||||||
|
|
||||||
|
|
||||||
|
def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
    """Dequantize one MXFP4 block/scale pair into a contiguous bf16 tensor.

    Work happens on the GPU; the result is moved back to the device the
    blocks tensor originally lived on. Progress is logged on TP rank 0 only.
    """
    if get_tensor_model_parallel_rank() == 0:
        logger.info(f"Dequantize {debug_name} start")

    # Remember where the checkpoint tensor lived so the result lands there.
    original_device = w_blocks.device
    w_blocks, w_scales = w_blocks.cuda(), w_scales.cuda()

    w_bf16 = dequant_mxfp4(
        w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16
    )
    w_bf16 = w_bf16.transpose(-2, -1).contiguous()

    if get_tensor_model_parallel_rank() == 0:
        logger.info(
            f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
        )

    return w_bf16.to(original_device)
|
||||||
|
|
||||||
|
|
||||||
|
class _WeightCreator:
    """One-shot lazy wrapper around a weight-producing callable.

    Used by the loader to defer expensive MXFP4 dequantization until the
    tensor is actually consumed; the callable is dropped after its first
    (and only) invocation so captured tensors can be freed.
    """

    def __init__(self, fn):
        # Held until materialization, then cleared to release memory.
        self._fn = fn

    @staticmethod
    def maybe_materialize(obj):
        """Invoke and discard the wrapped callable; pass other objects through."""
        if not isinstance(obj, _WeightCreator):
            return obj
        result = obj._fn()
        obj._fn = None
        return result
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point exported for sglang's model registry/loader.
EntryClass = GptOssForCausalLM
|
||||||
@@ -457,6 +457,10 @@ class ServerArgs:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"trtllm_mla backend does not support speculative decoding yet."
|
"trtllm_mla backend does not support speculative decoding yet."
|
||||||
)
|
)
|
||||||
|
model_arch = self.get_hf_config().architectures[0]
|
||||||
|
if model_arch in ["GptOssForCausalLM"]:
|
||||||
|
self.attention_backend = "triton"
|
||||||
|
self.enable_triton_kernel_moe = True
|
||||||
|
|
||||||
# Set page size
|
# Set page size
|
||||||
if self.page_size is None:
|
if self.page_size is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user