[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
@@ -59,6 +60,7 @@ def kernel_paged_attention_2d(
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -95,7 +97,18 @@ def kernel_paged_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
|
||||
dtype=tl.float32)
|
||||
@@ -223,6 +236,8 @@ def chunked_prefill_paged_decode(
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
|
||||
if sm_scale is None:
|
||||
@@ -253,6 +268,7 @@ def chunked_prefill_paged_decode(
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
@@ -285,7 +301,7 @@ def chunked_prefill_paged_decode(
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len, sliding_window,
|
||||
kv_cache_dtype, alibi_slopes)
|
||||
kv_cache_dtype, alibi_slopes, sinks,)
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
@@ -334,6 +350,7 @@ def chunked_prefill_paged_decode(
|
||||
query_ptr=query,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
@@ -365,4 +382,5 @@ def chunked_prefill_paged_decode(
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ def kernel_unified_attention_2d(
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
@@ -53,6 +54,7 @@ def kernel_unified_attention_2d(
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
@@ -119,7 +121,16 @@ def kernel_unified_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
if not USE_SINKS:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
@@ -260,6 +271,8 @@ def unified_attention(
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@@ -267,6 +280,10 @@ def unified_attention(
|
||||
block_size = v.shape[1]
|
||||
assert q.element_size() >= 2 or block_size >= 32, \
|
||||
"Block size must be at least 32 for fp8"
|
||||
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], \
|
||||
"Sinks must be num_query_heads size"
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
@@ -299,6 +316,7 @@ def unified_attention(
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
@@ -318,6 +336,7 @@ def unified_attention(
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
|
||||
@@ -275,6 +275,7 @@ def fused_moe_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_bias_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
@@ -303,6 +304,8 @@ def fused_moe_kernel(
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bbe, # bias expert stride
|
||||
stride_bbn, # bias N stride
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
@@ -321,6 +324,7 @@ def fused_moe_kernel(
|
||||
use_int8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
per_channel_quant: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
UPGRADE: tl.constexpr,
|
||||
UPGRADE_A_OFFS: tl.constexpr,
|
||||
UPGRADE_B_OFFS: tl.constexpr,
|
||||
@@ -447,6 +451,10 @@ def fused_moe_kernel(
|
||||
else:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
if HAS_BIAS:
|
||||
# bias shape: [num_experts, N]
|
||||
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
|
||||
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
@@ -494,7 +502,8 @@ def fused_moe_kernel(
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K
|
||||
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
@@ -548,7 +557,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
use_int4_w4a16: bool,
|
||||
orig_acc_dtype: torch.dtype,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None) -> None:
|
||||
block_shape: Optional[list[int]] = None,
|
||||
B_bias: Optional[torch.Tensor] = None) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@@ -580,7 +590,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), META['SPLIT_K'])
|
||||
|
||||
HAS_BIAS = B_bias is not None
|
||||
if (use_int8_w8a16 or use_int4_w4a16) and \
|
||||
block_shape is not None and block_shape[1] > 0:
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
@@ -592,19 +602,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
num_experts=B.shape[0],
|
||||
bit=4 if use_int4_w4a16 else 8)
|
||||
# TODO: missing config for BLOCK_SIZE_K
|
||||
# config = config.copy()
|
||||
# config.update(
|
||||
# get_moe_wna16_block_config(config=config,
|
||||
# use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
# num_valid_tokens=num_tokens,
|
||||
# size_k=A.shape[1],
|
||||
# size_n=B.shape[1],
|
||||
# num_experts=B.shape[1],
|
||||
# group_size=block_shape[1],
|
||||
# real_top_k=top_k,
|
||||
# block_size_m=config["BLOCK_SIZE_M"]))
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.shape[1],
|
||||
size_n=B.shape[1],
|
||||
num_experts=B.shape[1],
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"]))
|
||||
|
||||
if False and use_moe_wna16_cuda:
|
||||
if use_moe_wna16_cuda:
|
||||
bit = 4 if use_int4_w4a16 else 8
|
||||
ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
@@ -661,6 +671,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
@@ -689,6 +700,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
@@ -699,6 +712,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
FAST_F32_TO_BF16 = True,
|
||||
**config,
|
||||
@@ -1103,13 +1117,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
per_channel_quant, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@@ -1133,7 +1149,9 @@ def inplace_fused_experts_fake(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1167,14 +1185,16 @@ def outplace_fused_experts(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, per_channel_quant,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@@ -1197,7 +1217,9 @@ def outplace_fused_experts_fake(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@@ -1248,7 +1270,9 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
allow_deep_gemm: bool = False,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
N = w1.shape[1]
|
||||
@@ -1293,7 +1317,10 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
@@ -1319,6 +1346,8 @@ def fused_experts_impl(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
@@ -1498,7 +1527,19 @@ def fused_experts_impl(
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
orig_acc_dtype=hidden_states.dtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias)
|
||||
|
||||
# TODO fused kernel
|
||||
def swiglu_oai(gate_up):
    """SwiGLU variant used by gpt-oss.

    The last dim of ``gate_up`` interleaves gate (even channels) and
    linear "up" (odd channels) values. Both branches are clamped, the
    gate goes through a scaled SiLU (sigmoid(alpha * g) * g), and the
    up branch carries a +1 bias before the elementwise product.
    """
    alpha = 1.702
    limit = 7.0
    # Even channels -> gate (clamped from above only),
    # odd channels -> up (clamped symmetrically).
    gate = gate_up[..., ::2].clamp(min=None, max=limit)
    up = gate_up[..., 1::2].clamp(min=-limit, max=limit)
    # gate * sigmoid(alpha * gate) is the "quick GELU" / scaled-SiLU form.
    return (up + 1) * (gate * torch.sigmoid(gate * alpha))
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
@@ -1506,6 +1547,8 @@ def fused_experts_impl(
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "swiglu_oai":
|
||||
intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
@@ -1543,7 +1586,8 @@ def fused_experts_impl(
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
orig_acc_dtype=hidden_states.dtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias)
|
||||
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
@@ -1578,6 +1622,8 @@ def fused_moe(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1661,7 +1707,9 @@ def fused_moe(
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias)
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@@ -1805,7 +1853,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
@@ -1835,7 +1885,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ class MoEConfig:
|
||||
|
||||
max_num_tokens: int = MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@@ -443,6 +445,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.fused_experts = fused_experts # type: ignore
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
self.has_bias = self.moe.has_bias
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
@@ -502,6 +505,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
@@ -512,6 +523,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
|
||||
# Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
@@ -634,6 +652,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
@@ -840,6 +860,7 @@ class FusedMoE(torch.nn.Module):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
has_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@@ -920,6 +941,7 @@ class FusedMoE(torch.nn.Module):
|
||||
in_dtype=params_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
|
||||
618
model_executor/models/gpt_oss.py
Normal file
618
model_executor/models/gpt_oss.py
Normal file
@@ -0,0 +1,618 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from .utils import extract_layer_index, maybe_prefix
|
||||
|
||||
|
||||
class OAIAttention(nn.Module):
    """gpt-oss self-attention block.

    Pre-norm residual layout: ``t = RMSNorm(x)`` -> QKV projection ->
    YaRN-scaled RoPE -> attention with learned per-head "sink" logits ->
    output projection, then the input is added back (residual).
    A sliding window is applied only on even layer indices.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Layer index is recovered from the module prefix (".../block.N/...").
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # YaRN rope scaling; all parameters are read straight from the
        # HF config's rope_scaling dict.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_scaling["beta_fast"],
                "beta_slow":
                config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
        #                         else torch.bfloat16)
        # NOTE(review): sinks are forced to bf16 for now; the fp32 TRTLLM
        # path above is disabled — confirm before re-enabling.
        attention_sink_dtype = torch.bfloat16
        # One learned sink logit per local (TP-sharded) attention head.
        # NOTE(review): requires_grad=False here is an argument to
        # torch.empty, NOT to nn.Parameter — the Parameter itself keeps
        # requires_grad=True. Likely intended on the Parameter; confirm.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=attention_sink_dtype,
                        requires_grad=False))

        # Pre-attention RMSNorm (the residual add happens in forward).
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

        # Per-rank split sizes of the fused QKV output, used in forward().
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Apply pre-norm attention and return output + residual."""
        t = self.norm(hidden_states)

        qkv, _ = self.qkv(t)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)

        # Residual add uses the un-normalized input.
        return output + hidden_states
|
||||
|
||||
|
||||
class MLPBlock(torch.nn.Module):
    """Pre-norm MoE feed-forward block: RMSNorm -> router -> fused experts,
    with a residual connection around the whole block."""

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = (dist.get_world_size()
                           if dist.is_initialized() else 1)
        # The intermediate dimension must shard evenly across ranks.
        assert config.intermediate_size % self.world_size == 0
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
        # Router projection is kept in bf16 explicitly.
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swiglu_oai",
            prefix=f"{prefix}.experts",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route the normed input through the experts; add the residual."""
        normed = self.norm(x)
        routing_logits = self.router(normed)
        expert_out = self.experts(hidden_states=normed,
                                  router_logits=routing_logits)
        return x + expert_out
|
||||
|
||||
|
||||
class TransformerBlock(torch.nn.Module):
    """One gpt-oss decoder layer: attention block followed by the MoE MLP
    block. Each sub-block performs its own pre-norm and residual add."""

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
        cache_config: Optional[CacheConfig] = None,
    ):
        """Build one decoder layer.

        Args:
            config: HF model config for gpt-oss.
            quant_config: quantization settings, forwarded to both
                sub-blocks.
            prefix: module-name prefix used for weight loading and for
                recovering the layer index.
            cache_config: optional KV-cache settings, forwarded to the
                attention block (new, defaults to None for backward
                compatibility).
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        # Fix: forward quant_config (and cache_config) to the attention
        # block. Previously only the MLP received quant_config, which
        # silently left attention unquantized and its cache config unset,
        # even though OAIAttention accepts both.
        self.attn = OAIAttention(config,
                                 quant_config=quant_config,
                                 cache_config=cache_config,
                                 prefix=f"{prefix}.attn")
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Run attention then MLP; residual adds happen inside each block."""
        attn_output = self.attn(hidden_states, positions)
        return self.mlp(attn_output)
|
||||
|
||||
|
||||
@support_torch_compile
class GptOssModel(nn.Module):
    """gpt-oss transformer backbone.

    Embedding -> num_hidden_layers TransformerBlocks -> final RMSNorm.
    Returns hidden states; the LM head lives in GptOssForCausalLM.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        # (Removed a dead no-op self-assignment of config.hidden_size.)
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.layers = torch.nn.ModuleList([
            TransformerBlock(
                self.config,
                quant_config=self.quant_config,
                # Prefix encodes the layer index; sub-modules recover it
                # via extract_layer_index.
                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
            ) for layer_idx in range(self.config.num_hidden_layers)
        ])
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Embed tokens, run all decoder layers, and apply the final norm."""
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x, positions)
        return self.norm(x)
|
||||
|
||||
|
||||
class GptOssForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config.hf_config
|
||||
self.model = GptOssModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.model_config.vocab_size,
|
||||
self.model_config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
assert intermediate_tensors is None
|
||||
assert inputs_embeds is None
|
||||
return self.model(input_ids, positions)
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
    def _load_weights_mxfp4(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load an MXFP4-quantized checkpoint into this model's parameters.

        For each checkpoint tensor this: (1) maps the checkpoint name onto
        the corresponding parameter name, (2) slices the expert dimension
        for the current expert-parallel (EP) rank or the intermediate/head
        dimension for the current tensor-parallel (TP) rank, and (3) hands
        the shard to the parameter's ``weight_loader`` (or copies it
        directly for attention sinks).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        # Checkpoint-name fragments -> this module's naming scheme.
        rename_mapping = {
            "self_attn": "attn",
            "input_layernorm.weight": "attn.norm.weight",
            "post_attention_layernorm.weight": "mlp.norm.weight",
            "embed_tokens": "embedding",
        }

        def maybe_rename(name: str) -> str:
            # Apply the first matching rename; leave unknown names untouched.
            for remap_name, new_name in rename_mapping.items():
                if remap_name in name:
                    return name.replace(remap_name, new_name)
            return name

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Block size used for the MXFP4 scale/slice arithmetic below.
        mxfp4_block = 32

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.model_config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        # Round the per-rank share up to a whole number of mxfp4 blocks.
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.model_config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.model_config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        for name, weight in weights:
            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            if "gate_up_proj_blocks" in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace("gate_up_proj_blocks", "w13_weight")

                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Factor of 2: gate and up are interleaved in dim 1.
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_blocks" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_blocks", "w2_weight")
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Last dim is packed 2-per-byte, hence the // 2 bounds.
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "gate_up_proj_scales" in name:
                # Handle MLP gate and up projection weights scale
                new_name = name.replace("gate_up_proj_scales",
                                        "w13_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_scales" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_scales", "w2_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # One scale per mxfp4 block, so bounds shrink by the
                    # block size.
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_bias")

                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_bias")
                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    # NOTE(review): this zeroes the checkpoint tensor
                    # in-place on non-zero ranks rather than skipping it.
                    if tp_rank != 0:
                        weight.zero_()
                weight_loader(param,
                              weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                name = name.replace("self_attn", "attn")
                param = params_dict[name]
                # One sink value per head; take this rank's head slice.
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
                # q/k/v checkpoint tensors load as shards of the fused qkv
                # parameter via its shard-aware weight_loader.
                shard_id = ("q" if "q_proj" in name else
                            "k" if "k_proj" in name else "v")
                name = name.replace("self_attn", "attn")
                param_name = name.replace(f"{shard_id}_proj", "qkv")
                param = params_dict[param_name]
                weight_loader = param.weight_loader
                weight_loader(param, weight, loaded_shard_id=shard_id)
                loaded_params.add(param_name)
            else:
                # Handle all other weights with potential renaming
                renamed_name = maybe_rename(name)
                if renamed_name not in params_dict:
                    continue
                param = params_dict[renamed_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(renamed_name)

        return loaded_params
|
||||
|
||||
    def _load_weights_other(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a non-MXFP4 (e.g. bf16) checkpoint into this model.

        Mirrors :meth:`_load_weights_mxfp4` but for unquantized expert
        weights: names are remapped, expert/intermediate/head dimensions
        are sliced for the current expert-parallel (EP) or tensor-parallel
        (TP) rank, and expert weights are permuted from the checkpoint's
        ``(E, in, out)`` layout before being copied into the parameters
        directly (no ``weight_loader`` for the expert tensors).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        # Checkpoint-name fragments -> this module's naming scheme.
        rename_mapping = {
            "self_attn": "attn",
            "input_layernorm.weight": "attn.norm.weight",
            "post_attention_layernorm.weight": "mlp.norm.weight",
            "embed_tokens": "embedding",
        }

        def maybe_rename(name: str) -> str:
            # Apply the first matching rename; leave unknown names untouched.
            for remap_name, new_name in rename_mapping.items():
                if remap_name in name:
                    return name.replace(remap_name, new_name)
            return name

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.model_config.intermediate_size

        # No mxfp4 block rounding here: plain ceil-division split.
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.model_config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.model_config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        for name, weight in weights:
            if ".experts.gate_up_proj" in name and "bias" not in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace(".experts.gate_up_proj",
                                        ".experts.w13_weight")

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Factor of 2: gate and up are interleaved in the last
                    # dim of the checkpoint layout.
                    narrow_weight = weight[:, :,
                                           2 * tp_rank_start:2 * tp_rank_end]

                # Checkpoint stores (E, in, out); parameter expects the
                # transposed per-expert layout.
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[new_name]

                param.copy_(narrow_weight)
                loaded_params.add(new_name)

            elif ".experts.down_proj" in name and "bias" not in name:
                # Handle MLP down projection weights
                new_name = name.replace(".experts.down_proj",
                                        ".experts.w2_weight")

                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[new_name]

                param.copy_(narrow_weight)
                loaded_params.add(new_name)

            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_bias")

                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[new_name]

                param.copy_(narrow_weight)
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_bias")

                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    # NOTE(review): this zeroes the checkpoint tensor
                    # in-place on non-zero ranks rather than skipping it.
                    if tp_rank != 0:
                        weight.zero_()
                param = params_dict[new_name]
                param.copy_(weight)
                loaded_params.add(new_name)
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                name = name.replace("self_attn", "attn")
                param = params_dict[name]
                # One sink value per head; take this rank's head slice.
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
                # q/k/v checkpoint tensors load as shards of the fused qkv
                # parameter via its shard-aware weight_loader.
                shard_id = ("q" if "q_proj" in name else
                            "k" if "k_proj" in name else "v")
                name = name.replace("self_attn", "attn")
                param_name = name.replace(f"{shard_id}_proj", "qkv")
                param = params_dict[param_name]
                weight_loader = param.weight_loader
                weight_loader(param, weight, loaded_shard_id=shard_id)
                loaded_params.add(param_name)
            else:
                # Handle all other weights with potential renaming

                renamed_name = maybe_rename(name)
                if renamed_name not in params_dict:
                    continue
                param = params_dict[renamed_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(renamed_name)

        return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
quant_method = (self.model_config.quantization_config['quant_method']
|
||||
if hasattr(self.model_config, "quantization_config")
|
||||
else None)
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(weights)
|
||||
else:
|
||||
return self._load_weights_other(weights)
|
||||
@@ -61,6 +61,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
"GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
|
||||
@@ -126,7 +126,8 @@ def use_rocm_custom_paged_attention(
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
sinks: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
@@ -143,7 +144,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
and envs.VLLM_ROCM_USE_AITER) and sinks is None)
|
||||
|
||||
else:
|
||||
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
@@ -153,7 +154,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
|
||||
@@ -73,7 +73,7 @@ IMAGE_TOKEN = "<image>"
|
||||
IMAGE_ATOM_ID = -300
|
||||
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
|
||||
|
||||
AutoConfig.register("aimv2", AIMv2Config)
|
||||
AutoConfig.register("aimv2", AIMv2Config, exist_ok=True)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
@@ -90,6 +90,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
@@ -132,6 +133,13 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
self.force_prefill_decode_attn = \
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -257,7 +265,8 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
sm_scale=self.scale,
|
||||
sinks=self.sinks)
|
||||
|
||||
else:
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
@@ -280,6 +289,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user