diff --git a/attention/ops/chunked_prefill_paged_decode.py b/attention/ops/chunked_prefill_paged_decode.py index 4f83934..b079ab1 100644 --- a/attention/ops/chunked_prefill_paged_decode.py +++ b/attention/ops/chunked_prefill_paged_decode.py @@ -28,6 +28,7 @@ def kernel_paged_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -59,6 +60,7 @@ def kernel_paged_attention_2d( stride_v_cache_3: tl.int64, # int filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] + USE_SINKS: tl.constexpr, # bool ): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -95,7 +97,18 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + if not USE_SINKS: + M = tl.full([num_queries_per_kv_padded], + float("-inf"), + dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_head_idx, + mask=head_mask, + other=float("-inf"), + ).to(dtype=tl.float32) + # M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -223,6 +236,8 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + # Optional tensor for sinks + sinks=None, ): if sm_scale is None: @@ -253,6 +268,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + sinks=sinks, ) block_size = value_cache.shape[3] @@ -285,7 +301,7 @@ def chunked_prefill_paged_decode( block_size, num_queries_per_kv, max_seq_len, sliding_window, - kv_cache_dtype, 
alibi_slopes) + kv_cache_dtype, alibi_slopes, sinks,) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // @@ -334,6 +350,7 @@ def chunked_prefill_paged_decode( query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seq_lens, alibi_slopes_ptr=alibi_slopes, @@ -365,4 +382,5 @@ def chunked_prefill_paged_decode( stride_v_cache_3=value_cache.stride(3), filter_by_query_len=True, query_start_len_ptr=query_start_loc, + USE_SINKS=sinks is not None, ) diff --git a/attention/ops/triton_unified_attention.py b/attention/ops/triton_unified_attention.py index 92c09e6..585238e 100644 --- a/attention/ops/triton_unified_attention.py +++ b/attention/ops/triton_unified_attention.py @@ -34,6 +34,7 @@ def kernel_unified_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -53,6 +54,7 @@ def kernel_unified_attention_2d( HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int @@ -119,7 +121,16 @@ def kernel_unified_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if not USE_SINKS: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + # M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = 
tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -260,6 +271,8 @@ def unified_attention( k_descale, v_descale, alibi_slopes=None, + # Optional tensor for sinks + sinks=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -267,6 +280,10 @@ def unified_attention( block_size = v.shape[1] assert q.element_size() >= 2 or block_size >= 32, \ "Block size must be at least 32 for fp8" + + if sinks is not None: + assert sinks.shape[0] == q.shape[1], \ + "Sinks must be num_query_heads size" use_alibi_slopes = alibi_slopes is not None @@ -299,6 +316,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, @@ -318,6 +336,7 @@ def unified_attention( HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), diff --git a/model_executor/layers/fused_moe/fused_moe.py b/model_executor/layers/fused_moe/fused_moe.py index 9d74c68..a16b66a 100644 --- a/model_executor/layers/fused_moe/fused_moe.py +++ b/model_executor/layers/fused_moe/fused_moe.py @@ -275,6 +275,7 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + b_bias_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, @@ -303,6 +304,8 @@ def fused_moe_kernel( stride_bse, stride_bsk, stride_bsn, + stride_bbe, # bias expert stride + stride_bbn, # bias N stride # Block size for block-wise quantization group_n: tl.constexpr, group_k: tl.constexpr, @@ -321,6 +324,7 @@ def fused_moe_kernel( use_int8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, per_channel_quant: tl.constexpr, + HAS_BIAS: tl.constexpr, UPGRADE: tl.constexpr, UPGRADE_A_OFFS: tl.constexpr, UPGRADE_B_OFFS: tl.constexpr, @@ -447,6 +451,10 @@ def fused_moe_kernel( else: a_scale = tl.load(a_scale_ptr) 
b_scale = tl.load(b_scale_ptr + off_experts) + if HAS_BIAS: + # bias shape: [num_experts, N] + bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn + bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -494,7 +502,8 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K - + if HAS_BIAS: + accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, @@ -548,7 +557,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int4_w4a16: bool, orig_acc_dtype: torch.dtype, per_channel_quant: bool, - block_shape: Optional[list[int]] = None) -> None: + block_shape: Optional[list[int]] = None, + B_bias: Optional[torch.Tensor] = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -580,7 +590,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, A.shape[0] * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), META['SPLIT_K']) - + HAS_BIAS = B_bias is not None if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -592,19 +602,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, num_experts=B.shape[0], bit=4 if use_int4_w4a16 else 8) # TODO: missing config for BLOCK_SIZE_K - # config = config.copy() - # config.update( - # get_moe_wna16_block_config(config=config, - # use_moe_wna16_cuda=use_moe_wna16_cuda, - # num_valid_tokens=num_tokens, - # size_k=A.shape[1], - # size_n=B.shape[1], - # num_experts=B.shape[1], - # group_size=block_shape[1], - # real_top_k=top_k, - # 
block_size_m=config["BLOCK_SIZE_M"])) + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.shape[1], + size_n=B.shape[1], + num_experts=B.shape[1], + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"])) - if False and use_moe_wna16_cuda: + if use_moe_wna16_cuda: bit = 4 if use_int4_w4a16 else 8 ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, topk_weights if mul_routed_weight else None, @@ -661,6 +671,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, A, B, C, + B_bias, A_scale, B_scale, topk_weights, @@ -689,6 +700,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if B_scale is not None and B_scale.ndim == 3 else 0, B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_bias.stride(0) if B_bias is not None else 0, + B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, @@ -699,6 +712,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, + HAS_BIAS=HAS_BIAS, BLOCK_SIZE_K=BLOCK_SIZE_K, FAST_F32_TO_BF16 = True, **config, @@ -1103,13 +1117,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + 
block_shape, w1_bias, w2_bias) def inplace_fused_experts_fake( @@ -1133,7 +1149,9 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: pass @@ -1167,14 +1185,16 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( @@ -1197,7 +1217,9 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1248,7 +1270,9 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + allow_deep_gemm: bool = False, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: # For now, disable DeepGemm 
for small N (<= 512) until better # permute/unpermute ops are available. N = w1.shape[1] @@ -1293,7 +1317,10 @@ def fused_experts(hidden_states: torch.Tensor, w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) def fused_experts_impl( @@ -1319,6 +1346,8 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: @@ -1498,7 +1527,19 @@ def fused_experts_impl( use_int4_w4a16=use_int4_w4a16, orig_acc_dtype=hidden_states.dtype, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w1_bias) + + # TODO fused kernel + def swiglu_oai(gate_up): + alpha = 1.702 + limit = 7.0 + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + gated_output = (up + 1) * glu + return gated_output if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1506,6 +1547,8 @@ def fused_experts_impl( elif activation == "gelu": torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "swiglu_oai": + intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N)) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") @@ -1543,7 +1586,8 @@ def fused_experts_impl( use_int4_w4a16=use_int4_w4a16, orig_acc_dtype=hidden_states.dtype, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w2_bias) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) @@ -1578,6 +1622,8 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, 
a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1661,7 +1707,9 @@ def fused_moe( w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias) class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -1805,7 +1853,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1835,7 +1885,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) return intermediate_cache3 diff --git a/model_executor/layers/fused_moe/layer.py b/model_executor/layers/fused_moe/layer.py index 074e690..337e0c0 100644 --- a/model_executor/layers/fused_moe/layer.py +++ b/model_executor/layers/fused_moe/layer.py @@ -226,6 +226,8 @@ class MoEConfig: max_num_tokens: int = MOE_DP_CHUNK_SIZE + has_bias: bool = False + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -443,6 +445,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self.fused_experts = fused_experts # type: ignore self.topk_indices_dtype = None self.moe = moe + self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: @@ -502,6 +505,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): 
requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) + if self.has_bias: + w13_bias = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty( @@ -512,6 +523,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + if self.has_bias: + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which @@ -634,6 +652,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -840,6 +860,7 @@ class FusedMoE(torch.nn.Module): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + has_bias: bool = False, ): super().__init__() if params_dtype is None: @@ -920,6 +941,7 @@ class FusedMoE(torch.nn.Module): in_dtype=params_dtype, quant_dtype=quant_dtype, max_num_tokens=MOE_DP_CHUNK_SIZE, + has_bias=has_bias, ) self.moe_config = moe self.quant_config = quant_config diff --git a/model_executor/models/gpt_oss.py b/model_executor/models/gpt_oss.py new file mode 100644 index 0000000..7d4293a --- /dev/null +++ b/model_executor/models/gpt_oss.py @@ -0,0 +1,618 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn +from transformers import GptOssConfig + +from vllm import envs +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv + +from .utils import extract_layer_index, maybe_prefix + + +class OAIAttention(nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + 
base=config.rope_theta, + dtype=torch.float32, + rope_scaling={ + "rope_type": + "yarn", + "factor": + config.rope_scaling["factor"], + "original_max_position_embeddings": + config.rope_scaling["original_max_position_embeddings"], + "beta_fast": + config.rope_scaling["beta_fast"], + "beta_slow": + config.rope_scaling["beta_slow"], + }, + is_neox_style=True, + ) + + tp_size = get_tensor_model_parallel_world_size() + + # attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION + # else torch.bfloat16) + attention_sink_dtype = torch.bfloat16 + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads // tp_size, + dtype=attention_sink_dtype, + requires_grad=False)) + + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.q_size = self.num_attention_heads * self.head_dim // tp_size + self.kv_size = self.num_key_value_heads * self.head_dim // tp_size + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + + self.qkv = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_attention_heads, + total_num_kv_heads=self.num_key_value_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.num_local_attention_heads = config.num_attention_heads // tp_size + self.num_local_key_value_heads = config.num_key_value_heads // tp_size + + # Only apply sliding window to every other layer + sliding_window = (config.sliding_window if self.layer_idx % + 2 == 0 else None) + self.attn = Attention( + self.num_local_attention_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_local_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + 
sinks=self.sinks, + ) + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + t = self.norm(hidden_states) + + qkv, _ = self.qkv(t) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + + return output + hidden_states + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + layer_idx: int, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.router = torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + self.experts = FusedMoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False, + has_bias=True, + activation="swiglu_oai") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.router(t) + t = self.experts(hidden_states=t, router_logits=g) + return x + t + + +class TransformerBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.mlp = MLPBlock(config, + self.layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, hidden_states: torch.Tensor, + 
positions: torch.Tensor) -> torch.Tensor: + attn_output = self.attn(hidden_states, positions) + output = self.mlp(attn_output) + return output + + +@support_torch_compile +class GptOssModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + self.layers = torch.nn.ModuleList([ + TransformerBlock( + self.config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, f"block.{layer_idx}"), + ) for layer_idx in range(self.config.num_hidden_layers) + ]) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + x = self.embedding(input_ids) + for layer in self.layers: + x = layer(x, positions) + x = self.norm(x) + return x + + +class GptOssForCausalLM(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config.hf_config + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.model_config.vocab_size, + self.model_config.hidden_size, + ) + self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + assert intermediate_tensors is None + assert inputs_embeds is None + return self.model(input_ids, positions) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = 
self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def _load_weights_mxfp4( + self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rename_mapping = { + "self_attn": "attn", + "input_layernorm.weight": "attn.norm.weight", + "post_attention_layernorm.weight": "mlp.norm.weight", + "embed_tokens": "embedding", + } + + def maybe_rename(name: str) -> str: + for remap_name, new_name in rename_mapping.items(): + if remap_name in name: + return name.replace(remap_name, new_name) + return name + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.model_config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + # Attention heads per rank + heads_per_rank = self.model_config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + use_ep = self.vllm_config.parallel_config.enable_expert_parallel + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.model_config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + for name, weight in weights: + # FIXME(woosuk): Remove this after testing. 
+ weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", + "w13_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] 
+ + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_bias") + + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
def _load_weights_other(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load a non-MXFP4 checkpoint into this module's parameters.

    Each ``(name, tensor)`` pair is matched against a fixed list of name
    patterns (expert gate/up and down projections, their biases, attention
    sinks, q/k/v projections); anything else is renamed via a small
    substring mapping and handed to the parameter's ``weight_loader``.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The names (as keyed in ``self.named_parameters()``) of every
        parameter that was written.

    Raises:
        KeyError: If an expert, sinks or q/k/v weight maps to a name that
            is not in ``self.named_parameters()``. Unrecognised weights in
            the fallback branch are skipped silently instead.
    """
    rename_mapping = {
        "self_attn": "attn",
        "input_layernorm.weight": "attn.norm.weight",
        "post_attention_layernorm.weight": "mlp.norm.weight",
        "embed_tokens": "embedding",
    }

    def maybe_rename(name: str) -> str:
        # Only the first matching entry is applied.
        for remap_name, new_name in rename_mapping.items():
            if remap_name in name:
                return name.replace(remap_name, new_name)
        return name

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    intermediate_size = self.model_config.intermediate_size

    # Slicing bounds along the intermediate dimension for this TP rank.
    # cdiv rounds up, so the last rank's end is clamped.
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    # Attention heads owned by this TP rank (used for the sinks tensor).
    heads_per_rank = self.model_config.num_attention_heads // tp_size
    head_start = tp_rank * heads_per_rank

    # Expert-parallel bounds: a contiguous range of expert indices.
    use_ep = self.vllm_config.parallel_config.enable_expert_parallel
    ep_group = get_ep_group()
    ep_size = ep_group.world_size
    ep_rank = ep_group.rank
    num_experts = self.model_config.num_local_experts
    experts_per_rank = num_experts // ep_size
    ep_rank_start = ep_rank * experts_per_rank
    ep_rank_end = (ep_rank + 1) * experts_per_rank

    for name, weight in weights:
        if ".experts.gate_up_proj" in name and "bias" not in name:
            # MLP gate and up projection weights. The "bias" guard is
            # needed because "gate_up_proj_bias" also contains this
            # substring.
            new_name = name.replace(".experts.gate_up_proj",
                                    ".experts.w13_weight")

            # Since the weight is shuffled, we can slice directly: each
            # rank takes 2x its intermediate range from the last dim.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, :,
                                       2 * tp_rank_start:2 * tp_rank_end]

            # NOTE(review): the swap of the last two dims presumably
            # matches the parameter's storage layout -- confirm.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif ".experts.down_proj" in name and "bias" not in name:
            # MLP down projection weights.
            new_name = name.replace(".experts.down_proj",
                                    ".experts.w2_weight")

            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif "gate_up_proj_bias" in name:
            # MLP gate and up projection biases.
            new_name = name.replace("gate_up_proj_bias", "w13_bias")

            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end]

            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif "down_proj_bias" in name:
            # MLP down projection bias.
            new_name = name.replace("down_proj_bias", "w2_bias")

            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Only rank 0 loads the real bias, to avoid duplication
                # (presumably because per-rank outputs are summed --
                # TODO confirm). Other ranks load zeros. A fresh tensor
                # is used so the tensor yielded by the caller's iterator
                # is not overwritten in place.
                weight = torch.zeros_like(weight)
            param = params_dict[new_name]
            param.copy_(weight)
            loaded_params.add(new_name)

        elif "sinks" in name:
            # Attention sinks: keep this rank's heads along dim 0.
            name = name.replace("self_attn", "attn")
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)

        elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
            # Separate q/k/v checkpoint tensors are loaded into one
            # "qkv" parameter; its weight_loader receives the shard id.
            shard_id = ("q" if "q_proj" in name else
                        "k" if "k_proj" in name else "v")
            name = name.replace("self_attn", "attn")
            param_name = name.replace(f"{shard_id}_proj", "qkv")
            param = params_dict[param_name]
            weight_loader = param.weight_loader
            weight_loader(param, weight, loaded_shard_id=shard_id)
            loaded_params.add(param_name)

        else:
            # All other weights, with potential renaming. Names that do
            # not correspond to a parameter are skipped.
            renamed_name = maybe_rename(name)
            if renamed_name not in params_dict:
                continue
            param = params_dict[renamed_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
            loaded_params.add(renamed_name)

    return loaded_params


def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, dispatching on the quantization method.

    Uses the MXFP4 loader when the model config's quantization config
    names ``"mxfp4"`` as its ``quant_method``; every other case (including
    a missing, ``None`` or method-less quantization config) uses the
    generic loader.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names reported as loaded by the chosen
        loader.
    """
    quant_config = getattr(self.model_config, "quantization_config", None)
    if quant_config is None:
        # Attribute absent or explicitly None: not quantized.
        quant_method = None
    elif hasattr(quant_config, "get"):
        # Mapping-style config; a missing key means "no method" instead
        # of raising KeyError.
        quant_method = quant_config.get("quant_method")
    else:
        # Object-style config exposing the method as an attribute.
        quant_method = getattr(quant_config, "quant_method", None)

    if quant_method == "mxfp4":
        return self._load_weights_mxfp4(weights)
    return self._load_weights_other(weights)
and gqa_ratio <= 16) and max_seq_len <= 32768 and alibi_slopes is None and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) class RocmPlatform(Platform): diff --git a/transformers_utils/configs/ovis.py b/transformers_utils/configs/ovis.py index c2728f0..874aa1c 100644 --- a/transformers_utils/configs/ovis.py +++ b/transformers_utils/configs/ovis.py @@ -73,7 +73,7 @@ IMAGE_TOKEN = "" IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -AutoConfig.register("aimv2", AIMv2Config) +AutoConfig.register("aimv2", AIMv2Config, exist_ok=True) # ---------------------------------------------------------------------- diff --git a/v1/attention/backends/triton_attn.py b/v1/attention/backends/triton_attn.py index 5db592b..6a7c704 100644 --- a/v1/attention/backends/triton_attn.py +++ b/v1/attention/backends/triton_attn.py @@ -90,6 +90,7 @@ class TritonAttentionImpl(AttentionImpl): attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -132,6 +133,13 @@ class TritonAttentionImpl(AttentionImpl): self.fp8_dtype = current_platform.fp8_dtype() self.force_prefill_decode_attn = \ envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. 
Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") def forward( self, @@ -257,7 +265,8 @@ class TritonAttentionImpl(AttentionImpl): v_scale=layer._v_scale, alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], - sm_scale=self.scale) + sm_scale=self.scale, + sinks=self.sinks) else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -280,6 +289,7 @@ class TritonAttentionImpl(AttentionImpl): q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, ) return output