[gpt-oss] Add gpt-oss bf16 support
@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
     query_ptr,  # [num_tokens, num_query_heads, head_size]
     key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
     value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+    sink_ptr,  # [num_query_heads]
     block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
     seq_lens_ptr,  # [num_seqs]
     alibi_slopes_ptr,  # [num_query_heads]
@@ -59,6 +60,7 @@ def kernel_paged_attention_2d(
     stride_v_cache_3: tl.int64,  # int
     filter_by_query_len: tl.constexpr,  # bool
     query_start_len_ptr,  # [num_seqs+1]
+    USE_SINKS: tl.constexpr,  # bool
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -95,7 +97,18 @@ def kernel_paged_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    if not USE_SINKS:
+        M = tl.full([num_queries_per_kv_padded],
+                    float("-inf"),
+                    dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_head_idx,
+            mask=head_mask,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+    # M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+
     L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                    dtype=tl.float32)
@@ -223,6 +236,8 @@ def chunked_prefill_paged_decode(
     alibi_slopes=None,
     sliding_window=None,
     sm_scale=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):

     if sm_scale is None:
@@ -253,6 +268,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            sinks=sinks,
         )

     block_size = value_cache.shape[3]
@@ -285,7 +301,7 @@ def chunked_prefill_paged_decode(
         block_size,
         num_queries_per_kv,
         max_seq_len, sliding_window,
-        kv_cache_dtype, alibi_slopes)
+        kv_cache_dtype, alibi_slopes, sinks,)
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +350,7 @@ def chunked_prefill_paged_decode(
         query_ptr=query,
         key_cache_ptr=key_cache,
         value_cache_ptr=value_cache,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seq_lens,
         alibi_slopes_ptr=alibi_slopes,
@@ -365,4 +382,5 @@ def chunked_prefill_paged_decode(
         stride_v_cache_3=value_cache.stride(3),
         filter_by_query_len=True,
         query_start_len_ptr=query_start_loc,
+        USE_SINKS=sinks is not None,
     )
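Note on the kernel change above: seeding the running max `M` from `sink_ptr` (instead of `-inf`) while the running sum `L` keeps its usual initial value of 1.0 can be read as giving every head one extra "sink" logit that absorbs softmax mass but contributes no value. A dense, non-paged PyTorch reference of that formulation is sketched below; the function and tensor names are illustrative only and are not part of this diff.

```python
import torch

def sink_attention_reference(q, k, v, sink, scale):
    """Per-head attention with one value-less sink logit (illustrative only)."""
    # q: [heads, d], k/v: [heads, seq, d], sink: [heads]
    logits = torch.einsum("hd,hsd->hs", q, k) * scale      # [heads, seq]
    logits = torch.cat([sink[:, None], logits], dim=1)     # prepend sink logit
    probs = torch.softmax(logits.float(), dim=1)
    return torch.einsum("hs,hsd->hd", probs[:, 1:], v)     # sink adds no value
```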
@@ -34,6 +34,7 @@ def kernel_unified_attention_2d(
     query_ptr,  # [num_tokens, num_query_heads, head_size]
     key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
     value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    sink_ptr,  # [num_query_heads]
     block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
     seq_lens_ptr,  # [num_seqs]
     alibi_slopes_ptr,  # [num_query_heads]
@@ -53,6 +54,7 @@ def kernel_unified_attention_2d(
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -119,7 +121,16 @@ def kernel_unified_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if not USE_SINKS:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+    # M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

@@ -260,6 +271,8 @@ def unified_attention(
     k_descale,
     v_descale,
     alibi_slopes=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -268,6 +281,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"

+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must be num_query_heads size"
+
     use_alibi_slopes = alibi_slopes is not None

     block_size = v.shape[1]
@@ -299,6 +316,7 @@ def unified_attention(
         query_ptr=q,
         key_cache_ptr=k,
         value_cache_ptr=v,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seqused_k,
         alibi_slopes_ptr=alibi_slopes,
@@ -318,6 +336,7 @@ def unified_attention(
         HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
         USE_ALIBI_SLOPES=use_alibi_slopes,
         USE_SOFTCAP=(softcap > 0),
+        USE_SINKS=(sinks is not None),
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
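Because `USE_SINKS` is declared `tl.constexpr`, Triton specializes the kernel per value at compile time, so the `tl.load(sink_ptr + ...)` branch exists only in the variant compiled with sinks enabled; launching with `sink_ptr=None` and `USE_SINKS=False` therefore never touches the missing pointer. A minimal stand-alone illustration of that pattern (not vLLM code; names are made up):

```python
import triton
import triton.language as tl

@triton.jit
def add_optional(x_ptr, opt_ptr, out_ptr, USE_OPT: tl.constexpr):
    x = tl.load(x_ptr)
    if USE_OPT:
        # This branch is removed entirely from the variant compiled with
        # USE_OPT=False, so opt_ptr may be None in that case.
        x += tl.load(opt_ptr)
    tl.store(out_ptr, x)
```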
@@ -275,6 +275,7 @@ def fused_moe_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
+    b_bias_ptr,
     a_scale_ptr,
     b_scale_ptr,
     topk_weights_ptr,
@@ -303,6 +304,8 @@ def fused_moe_kernel(
     stride_bse,
     stride_bsk,
     stride_bsn,
+    stride_bbe,  # bias expert stride
+    stride_bbn,  # bias N stride
     # Block size for block-wise quantization
     group_n: tl.constexpr,
     group_k: tl.constexpr,
@@ -321,6 +324,7 @@ def fused_moe_kernel(
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
     UPGRADE: tl.constexpr,
     UPGRADE_A_OFFS: tl.constexpr,
     UPGRADE_B_OFFS: tl.constexpr,
@@ -447,6 +451,10 @@ def fused_moe_kernel(
         else:
             a_scale = tl.load(a_scale_ptr)
             b_scale = tl.load(b_scale_ptr + off_experts)
+    if HAS_BIAS:
+        # bias shape: [num_experts, N]
+        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
+        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -494,7 +502,8 @@ def fused_moe_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K
         b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K
+    if HAS_BIAS:
+        accumulator = accumulator + bias[None, :]
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token,
                              mask=token_mask,
@@ -548,7 +557,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        use_int4_w4a16: bool,
        orig_acc_dtype: torch.dtype,
        per_channel_quant: bool,
-       block_shape: Optional[list[int]] = None) -> None:
+       block_shape: Optional[list[int]] = None,
+       B_bias: Optional[torch.Tensor] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -580,7 +590,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                  A.shape[0] * top_k * config['BLOCK_SIZE_M'])
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
         B.shape[1], META['BLOCK_SIZE_N']), META['SPLIT_K'])
+    HAS_BIAS = B_bias is not None
     if (use_int8_w8a16 or use_int4_w4a16) and \
             block_shape is not None and block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
@@ -592,19 +602,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             num_experts=B.shape[0],
             bit=4 if use_int4_w4a16 else 8)
         # TODO: missing config for BLOCK_SIZE_K
-        # config = config.copy()
-        # config.update(
-        #     get_moe_wna16_block_config(config=config,
-        #         use_moe_wna16_cuda=use_moe_wna16_cuda,
-        #         num_valid_tokens=num_tokens,
-        #         size_k=A.shape[1],
-        #         size_n=B.shape[1],
-        #         num_experts=B.shape[1],
-        #         group_size=block_shape[1],
-        #         real_top_k=top_k,
-        #         block_size_m=config["BLOCK_SIZE_M"]))
+        config = config.copy()
+        config.update(
+            get_moe_wna16_block_config(config=config,
+                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
+                                       num_valid_tokens=num_tokens,
+                                       size_k=A.shape[1],
+                                       size_n=B.shape[1],
+                                       num_experts=B.shape[1],
+                                       group_size=block_shape[1],
+                                       real_top_k=top_k,
+                                       block_size_m=config["BLOCK_SIZE_M"]))

-        if False and use_moe_wna16_cuda:
+        if use_moe_wna16_cuda:
             bit = 4 if use_int4_w4a16 else 8
             ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                                topk_weights if mul_routed_weight else None,
@@ -661,6 +671,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             A,
             B,
             C,
+            B_bias,
             A_scale,
             B_scale,
             topk_weights,
@@ -689,6 +700,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             if B_scale is not None and B_scale.ndim == 3 else 0,
             B_scale.stride(1)
             if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_bias.stride(0) if B_bias is not None else 0,
+            B_bias.stride(1) if B_bias is not None else 0,
             0 if block_shape is None else block_shape[0],
             0 if block_shape is None else block_shape[1],
             MUL_ROUTED_WEIGHT=mul_routed_weight,
@@ -699,6 +712,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             per_channel_quant=per_channel_quant,
+            HAS_BIAS=HAS_BIAS,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
             FAST_F32_TO_BF16 = True,
             **config,
@@ -1103,13 +1117,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:
+        block_shape: Optional[List[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        per_channel_quant, global_num_experts, expert_map,
                        w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-                       block_shape)
+                       block_shape, w1_bias, w2_bias)


 def inplace_fused_experts_fake(
@@ -1133,7 +1149,9 @@ def inplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:
+        block_shape: Optional[List[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> None:
     pass


@@ -1167,14 +1185,16 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[List[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, activation, apply_router_weight_on_input,
                               use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
                               use_int4_w4a16, per_channel_quant,
                               global_num_experts, expert_map, w1_scale,
                               w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-                              block_shape)
+                              block_shape, w1_bias, w2_bias)


 def outplace_fused_experts_fake(
@@ -1197,7 +1217,9 @@ def outplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[List[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)


@@ -1248,7 +1270,9 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[list[int]] = None,
-                  allow_deep_gemm: bool = False) -> torch.Tensor:
+                  allow_deep_gemm: bool = False,
+                  w1_bias: Optional[torch.Tensor] = None,
+                  w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
     N = w1.shape[1]
@@ -1293,7 +1317,10 @@ def fused_experts(hidden_states: torch.Tensor,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
-       block_shape=block_shape)
+       block_shape=block_shape,
+       w1_bias=w1_bias,
+       w2_bias=w2_bias,
+       )


 def fused_experts_impl(
@@ -1319,6 +1346,8 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
@@ -1498,7 +1527,19 @@ def fused_experts_impl(
            use_int4_w4a16=use_int4_w4a16,
            orig_acc_dtype=hidden_states.dtype,
            per_channel_quant=per_channel_quant,
-           block_shape=block_shape)
+           block_shape=block_shape,
+           B_bias=w1_bias)
+
+        # TODO fused kernel
+        def swiglu_oai(gate_up):
+            alpha = 1.702
+            limit = 7.0
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=limit)
+            up = up.clamp(min=-limit, max=limit)
+            glu = gate * torch.sigmoid(gate * alpha)
+            gated_output = (up + 1) * glu
+            return gated_output

         if activation == "silu":
             torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1506,6 +1547,8 @@ def fused_experts_impl(
         elif activation == "gelu":
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        elif activation == "swiglu_oai":
+            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")

@@ -1543,7 +1586,8 @@ def fused_experts_impl(
            use_int4_w4a16=use_int4_w4a16,
            orig_acc_dtype=hidden_states.dtype,
            per_channel_quant=per_channel_quant,
-           block_shape=block_shape)
+           block_shape=block_shape,
+           B_bias=w2_bias)

         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
@@ -1578,6 +1622,8 @@ def fused_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1661,7 +1707,9 @@ def fused_moe(
         w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=block_shape)
+        block_shape=block_shape,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias)


 class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -1805,7 +1853,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
             per_channel_quant=self.per_channel_quant,
-            block_shape=self.block_shape)
+            block_shape=self.block_shape,
+            B_bias=None  # TODO support B_bias
+            )

         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1835,7 +1885,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
             per_channel_quant=self.per_channel_quant,
-            block_shape=self.block_shape)
+            block_shape=self.block_shape,
+            B_bias=None  # TODO support B_bias
+            )

         return intermediate_cache3
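Taken together, these fused_moe changes make each expert compute a biased gate/up projection, the clamped `swiglu_oai` activation, then a biased down projection, with the per-expert bias row added to the accumulator before the routed weight is multiplied in. A dense single-expert reference, assuming a `[2*I, H]` / `[H, I]` weight layout and the routed weight applied after the down projection (illustrative sketch, not the fused kernel):

```python
import torch

def expert_reference(x, w1, b1, w2, b2, routed_weight, alpha=1.702, limit=7.0):
    # x: [tokens, H], w1: [2*I, H], b1: [2*I], w2: [H, I], b2: [H]
    gate_up = x @ w1.T + b1
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]   # interleaved halves
    gate = gate.clamp(max=limit)                       # clamp as in swiglu_oai
    up = up.clamp(min=-limit, max=limit)
    h = (up + 1) * (gate * torch.sigmoid(gate * alpha))
    y = h @ w2.T + b2                                  # bias added before routing
    return y * routed_weight[:, None]                  # routed weight applied last
```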
@@ -226,6 +226,8 @@ class MoEConfig:

     max_num_tokens: int = MOE_DP_CHUNK_SIZE

+    has_bias: bool = False
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -443,6 +445,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.fused_experts = fused_experts  # type: ignore
         self.topk_indices_dtype = None
         self.moe = moe
+        self.has_bias = self.moe.has_bias

         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
@@ -502,6 +505,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                                        requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
+        if self.has_bias:
+            w13_bias = torch.nn.Parameter(torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w13_bias", w13_bias)
+            set_weight_attrs(w13_bias, extra_weight_attrs)

         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(torch.empty(
@@ -512,6 +523,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                                       requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
+        if self.has_bias:
+            w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
+                                                     hidden_size,
+                                                     dtype=params_dtype),
+                                         requires_grad=False)
+            layer.register_parameter("w2_bias", w2_bias)
+            set_weight_attrs(w2_bias, extra_weight_attrs)

     def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
         # Pad the weight tensor. This is an optimization on ROCm platform, which
@@ -634,6 +652,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
+            w1_bias=layer.w13_bias if self.has_bias else None,
+            w2_bias=layer.w2_bias if self.has_bias else None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
@@ -840,6 +860,7 @@ class FusedMoE(torch.nn.Module):
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        has_bias: bool = False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -920,6 +941,7 @@ class FusedMoE(torch.nn.Module):
             in_dtype=params_dtype,
             quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
+            has_bias=has_bias,
         )
         self.moe_config = moe
         self.quant_config = quant_config
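For reference, the bias parameters registered above are per-expert rows: `w13_bias` is `[num_experts, 2 * intermediate_size_per_partition]` (sharded with the column-parallel gate/up projection), and `w2_bias` is `[num_experts, hidden_size]`; the gpt-oss loaders later in this diff zero `w2_bias` on TP ranks other than 0 so the bias is added exactly once before results are all-reduced. A shape-only sketch with placeholder sizes (not the real model config):

```python
import torch

E, H, I_p = 32, 2880, 1440              # placeholder sizes for illustration
w13_bias = torch.zeros(E, 2 * I_p)      # gate/up bias, sharded per TP rank
w2_bias = torch.zeros(E, H)             # down-proj bias, added after 2nd GEMM
```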
model_executor/models/gpt_oss.py (new file, 618 lines)
@@ -0,0 +1,618 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from transformers import GptOssConfig
+
+from vllm import envs
+from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv
+
+from .utils import extract_layer_index, maybe_prefix
+
+
+class OAIAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: GptOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.layer_idx = extract_layer_index(prefix)
+        self.head_dim = config.head_dim
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.hidden_size = config.hidden_size
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            dtype=torch.float32,
+            rope_scaling={
+                "rope_type":
+                "yarn",
+                "factor":
+                config.rope_scaling["factor"],
+                "original_max_position_embeddings":
+                config.rope_scaling["original_max_position_embeddings"],
+                "beta_fast":
+                config.rope_scaling["beta_fast"],
+                "beta_slow":
+                config.rope_scaling["beta_slow"],
+            },
+            is_neox_style=True,
+        )
+
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
+        #                         else torch.bfloat16)
+        attention_sink_dtype = torch.bfloat16
+        self.sinks = torch.nn.Parameter(
+            torch.empty(config.num_attention_heads // tp_size,
+                        dtype=attention_sink_dtype,
+                        requires_grad=False))
+
+        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.q_size = self.num_attention_heads * self.head_dim // tp_size
+        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = config.rope_theta
+
+        self.qkv = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.num_attention_heads,
+            total_num_kv_heads=self.num_key_value_heads,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.num_attention_heads * self.head_dim,
+            output_size=self.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.num_local_attention_heads = config.num_attention_heads // tp_size
+        self.num_local_key_value_heads = config.num_key_value_heads // tp_size
+
+        # Only apply sliding window to every other layer
+        sliding_window = (config.sliding_window if self.layer_idx %
+                          2 == 0 else None)
+        self.attn = Attention(
+            self.num_local_attention_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_key_value_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            per_layer_sliding_window=sliding_window,
+            attn_type=AttentionType.DECODER,
+            prefix=f"{prefix}.attn",
+            sinks=self.sinks,
+        )
+
+    def forward(self, hidden_states: torch.Tensor,
+                positions: torch.Tensor) -> torch.Tensor:
+        t = self.norm(hidden_states)
+
+        qkv, _ = self.qkv(t)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        v = v.contiguous()
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+
+        return output + hidden_states
+
+
class MLPBlock(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GptOssConfig,
|
||||||
|
layer_idx: int,
|
||||||
|
quant_config: QuantizationConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.num_experts = config.num_local_experts
|
||||||
|
self.experts_per_token = config.num_experts_per_tok
|
||||||
|
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||||
|
self.router = torch.nn.Linear(config.hidden_size,
|
||||||
|
config.num_local_experts,
|
||||||
|
dtype=torch.bfloat16)
|
||||||
|
assert config.intermediate_size % self.world_size == 0
|
||||||
|
self.experts = FusedMoE(num_experts=config.num_local_experts,
|
||||||
|
top_k=config.num_experts_per_tok,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
reduce_results=True,
|
||||||
|
renormalize=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.experts",
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
|
has_bias=True,
|
||||||
|
activation="swiglu_oai")
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
t = self.norm(x)
|
||||||
|
g = self.router(t)
|
||||||
|
t = self.experts(hidden_states=t, router_logits=g)
|
||||||
|
return x + t
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GptOssConfig,
|
||||||
|
quant_config: QuantizationConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = extract_layer_index(prefix)
|
||||||
|
self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
|
||||||
|
self.mlp = MLPBlock(config,
|
||||||
|
self.layer_idx,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp")
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
|
positions: torch.Tensor) -> torch.Tensor:
|
||||||
|
attn_output = self.attn(hidden_states, positions)
|
||||||
|
output = self.mlp(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class GptOssModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = vllm_config.model_config.hf_config
|
||||||
|
self.quant_config = vllm_config.quant_config
|
||||||
|
self.config.hidden_size = self.config.hidden_size
|
||||||
|
self.embedding = VocabParallelEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
)
|
||||||
|
self.layers = torch.nn.ModuleList([
|
||||||
|
TransformerBlock(
|
||||||
|
self.config,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
|
||||||
|
) for layer_idx in range(self.config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.embedding(input_ids)
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, positions)
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GptOssForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.model_config = vllm_config.model_config.hf_config
|
||||||
|
self.model = GptOssModel(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"),
|
||||||
|
)
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.model_config.vocab_size,
|
||||||
|
self.model_config.hidden_size,
|
||||||
|
)
|
||||||
|
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
assert intermediate_tensors is None
|
||||||
|
assert inputs_embeds is None
|
||||||
|
return self.model(input_ids, positions)
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
+    def _load_weights_mxfp4(
+            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        rename_mapping = {
+            "self_attn": "attn",
+            "input_layernorm.weight": "attn.norm.weight",
+            "post_attention_layernorm.weight": "mlp.norm.weight",
+            "embed_tokens": "embedding",
+        }
+
+        def maybe_rename(name: str) -> str:
+            for remap_name, new_name in rename_mapping.items():
+                if remap_name in name:
+                    return name.replace(remap_name, new_name)
+            return name
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        mxfp4_block = 32
+
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        intermediate_size = self.model_config.intermediate_size
+        intermediate_size_block = intermediate_size // mxfp4_block
+        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
+                                                tp_size)
+        per_rank_intermediate_size = (per_rank_intermediate_size_block *
+                                      mxfp4_block)
+
+        # Calculate common slicing bounds for current rank
+        tp_rank_start = tp_rank * per_rank_intermediate_size
+        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
+                          intermediate_size)
+
+        # Attention heads per rank
+        heads_per_rank = self.model_config.num_attention_heads // tp_size
+        head_start = tp_rank * heads_per_rank
+
+        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+        ep_size = get_ep_group().world_size
+        ep_rank = get_ep_group().rank
+        num_experts = self.model_config.num_local_experts
+        experts_per_rank = num_experts // ep_size
+        ep_rank_start = ep_rank * experts_per_rank
+        ep_rank_end = (ep_rank + 1) * experts_per_rank
+
+        for name, weight in weights:
+            # FIXME(woosuk): Remove this after testing.
+            weight = weight.cuda()
+
+            if "gate_up_proj_blocks" in name:
+                # Handle MLP gate and up projection weights
+                new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+
+                # flat weight from (E, 2 * N, block_size, entry_per_block)
+                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
+                weight = weight.view(num_experts, 2 * intermediate_size,
+                                     -1).contiguous()
+
+                # Extract gate and up projection parts
+                # since the weight is shuffled, we can slice directly
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end,
+                                           ...]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+
+            elif "down_proj_blocks" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_blocks", "w2_weight")
+                # same flatten here, but since 2 mx4 value are packed in 1
+                # uint8, divide by 2
+                weight = weight.view(num_experts, -1,
+                                     intermediate_size // 2).contiguous()
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[...,
+                                           tp_rank_start // 2:tp_rank_end // 2]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+
+            elif "gate_up_proj_scales" in name:
+                # Handle MLP gate and up projection weights scale
+                new_name = name.replace("gate_up_proj_scales",
+                                        "w13_weight_scale")
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end,
+                                           ...]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+
+            elif "down_proj_scales" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_scales", "w2_weight_scale")
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[..., tp_rank_start //
+                                           mxfp4_block:tp_rank_end //
+                                           mxfp4_block]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+            elif "gate_up_proj_bias" in name:
+                # Handle MLP gate and up projection biases
+                new_name = name.replace("gate_up_proj_bias", "w13_bias")
+
+                # Extract gate and up projection bias parts
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param,
+                              narrow_weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+
+            elif "down_proj_bias" in name:
+                # Handle MLP down projection bias
+                new_name = name.replace("down_proj_bias", "w2_bias")
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if use_ep:
+                    weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    # (only load on rank 0 to avoid duplication)
+                    if tp_rank != 0:
+                        weight.zero_()
+                weight_loader(param,
+                              weight,
+                              weight_name=new_name,
+                              shard_id=None,
+                              expert_id=None)
+                loaded_params.add(new_name)
+            elif "sinks" in name:
+                # Handle attention sinks (distributed across ranks)
+                name = name.replace("self_attn", "attn")
+                param = params_dict[name]
+                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
+                param.data.copy_(narrow_weight)
+                loaded_params.add(name)
+            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
+                shard_id = ("q" if "q_proj" in name else
+                            "k" if "k_proj" in name else "v")
+                name = name.replace("self_attn", "attn")
+                param_name = name.replace(f"{shard_id}_proj", "qkv")
+                param = params_dict[param_name]
+                weight_loader = param.weight_loader
+                weight_loader(param, weight, loaded_shard_id=shard_id)
+                loaded_params.add(param_name)
+            else:
+                # Handle all other weights with potential renaming
+                renamed_name = maybe_rename(name)
+                if renamed_name not in params_dict:
+                    continue
+                param = params_dict[renamed_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, weight)
+                loaded_params.add(renamed_name)
+
+        return loaded_params
+
+    def _load_weights_other(
+            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        rename_mapping = {
+            "self_attn": "attn",
+            "input_layernorm.weight": "attn.norm.weight",
+            "post_attention_layernorm.weight": "mlp.norm.weight",
+            "embed_tokens": "embedding",
+        }
+
+        def maybe_rename(name: str) -> str:
+            for remap_name, new_name in rename_mapping.items():
+                if remap_name in name:
+                    return name.replace(remap_name, new_name)
+            return name
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        intermediate_size = self.model_config.intermediate_size
+
+        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
+        # Calculate common slicing bounds for current rank
+        tp_rank_start = tp_rank * per_rank_intermediate_size
+        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
+                          intermediate_size)
+
+        # Attention heads per rank
+        heads_per_rank = self.model_config.num_attention_heads // tp_size
+        head_start = tp_rank * heads_per_rank
+
+        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+        ep_size = get_ep_group().world_size
+        ep_rank = get_ep_group().rank
+        num_experts = self.model_config.num_local_experts
+        experts_per_rank = num_experts // ep_size
+        ep_rank_start = ep_rank * experts_per_rank
+        ep_rank_end = (ep_rank + 1) * experts_per_rank
+
+        for name, weight in weights:
+            if ".experts.gate_up_proj" in name and "bias" not in name:
+                # Handle MLP gate and up projection weights
+                new_name = name.replace(".experts.gate_up_proj",
+                                        ".experts.w13_weight")
+
+                # Extract gate and up projection parts
+                # since the weight is shuffled, we can slice directly
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:, :,
+                                           2 * tp_rank_start:2 * tp_rank_end]
+
+                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
+                param = params_dict[new_name]
+
+                param.copy_(narrow_weight)
+                loaded_params.add(new_name)
+
+            elif ".experts.down_proj" in name and "bias" not in name:
+                # Handle MLP down projection weights
+                new_name = name.replace(".experts.down_proj",
+                                        ".experts.w2_weight")
+
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
+                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
+                param = params_dict[new_name]
+
+                param.copy_(narrow_weight)
+                loaded_params.add(new_name)
+
+            elif "gate_up_proj_bias" in name:
+                # Handle MLP gate and up projection biases
+                new_name = name.replace("gate_up_proj_bias", "w13_bias")
+
+                # Extract gate and up projection bias parts
+                if use_ep:
+                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    narrow_weight = weight[:,
+                                           2 * tp_rank_start:2 * tp_rank_end]
+
+                param = params_dict[new_name]
+
+                param.copy_(narrow_weight)
+                loaded_params.add(new_name)
+
+            elif "down_proj_bias" in name:
+                # Handle MLP down projection bias
+                new_name = name.replace("down_proj_bias", "w2_bias")
+
+                if use_ep:
+                    weight = weight[ep_rank_start:ep_rank_end, ...]
+                else:
+                    # (only load on rank 0 to avoid duplication)
+                    if tp_rank != 0:
+                        weight.zero_()
+                param = params_dict[new_name]
+                param.copy_(weight)
+                loaded_params.add(new_name)
+            elif "sinks" in name:
+                # Handle attention sinks (distributed across ranks)
+                name = name.replace("self_attn", "attn")
+                param = params_dict[name]
+                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
+                param.data.copy_(narrow_weight)
+                loaded_params.add(name)
+            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
+                shard_id = ("q" if "q_proj" in name else
+                            "k" if "k_proj" in name else "v")
+                name = name.replace("self_attn", "attn")
+                param_name = name.replace(f"{shard_id}_proj", "qkv")
+                param = params_dict[param_name]
+                weight_loader = param.weight_loader
+                weight_loader(param, weight, loaded_shard_id=shard_id)
+                loaded_params.add(param_name)
+            else:
+                # Handle all other weights with potential renaming
+
+                renamed_name = maybe_rename(name)
+                if renamed_name not in params_dict:
+                    continue
+                param = params_dict[renamed_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, weight)
+                loaded_params.add(renamed_name)
+
+        return loaded_params
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        quant_method = (self.model_config.quantization_config['quant_method']
+                        if hasattr(self.model_config, "quantization_config")
+                        else None)
+        if quant_method == "mxfp4":
+            return self._load_weights_mxfp4(weights)
+        else:
+            return self._load_weights_other(weights)
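A hedged usage sketch for the new model class: the bf16 path added here (`_load_weights_other`) is taken when the checkpoint is not mxfp4-quantized. `LLM` and `SamplingParams` are vLLM's existing offline API; the checkpoint path below is a placeholder, not a reference to a specific published model.

```python
from vllm import LLM, SamplingParams

# Placeholder checkpoint path; dtype="bfloat16" selects the bf16 weight path.
llm = LLM(model="path/to/gpt-oss-bf16-checkpoint", dtype="bfloat16")
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```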
@@ -61,6 +61,7 @@ _TEXT_GENERATION_MODELS = {
     "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
+    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -126,7 +126,8 @@ def use_rocm_custom_paged_attention(
         max_seq_len: int,
         sliding_window: int,
         kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
+        alibi_slopes: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None) -> bool:

     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@@ -143,7 +144,7 @@ def use_rocm_custom_paged_attention(
                 and (gqa_ratio >= 1 and gqa_ratio <= 16)
                 and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                 and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER))
+                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)

     else:
         return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
@@ -153,7 +154,7 @@ def use_rocm_custom_paged_attention(
                 and (gqa_ratio >= 3 and gqa_ratio <= 16)
                 and max_seq_len <= 32768 and alibi_slopes is None
                 and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)


 class RocmPlatform(Platform):
@@ -73,7 +73,7 @@ IMAGE_TOKEN = "<image>"
 IMAGE_ATOM_ID = -300
 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]

-AutoConfig.register("aimv2", AIMv2Config)
+AutoConfig.register("aimv2", AIMv2Config, exist_ok=True)


 # ----------------------------------------------------------------------
@@ -90,6 +90,7 @@ class TritonAttentionImpl(AttentionImpl):
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
@@ -133,6 +134,13 @@ class TritonAttentionImpl(AttentionImpl):
         self.force_prefill_decode_attn = \
             envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                f"heads in the layer. Sinks shape: {sinks.shape}, "
+                f"num_heads: {num_heads}.")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -257,7 +265,8 @@ class TritonAttentionImpl(AttentionImpl):
                 v_scale=layer._v_scale,
                 alibi_slopes=self.alibi_slopes,
                 sliding_window=self.sliding_window[0],
-                sm_scale=self.scale)
+                sm_scale=self.scale,
+                sinks=self.sinks)

         else:
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
@@ -280,6 +289,7 @@ class TritonAttentionImpl(AttentionImpl):
                 q_descale=None,  # Not supported
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                sinks=self.sinks,
             )

         return output