diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index c46c8cd4d..469e4fde3 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend): self.window_kv_indptr = torch.zeros_like(kv_indptr_buf) self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator if not self.skip_prefill: self.qo_indptr = torch.zeros( @@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend): forward_batch.req_pool_indices, bs, self.device, + self.token_to_kv_pool_allocator, ) ) window_num_kv_splits = torch.empty( @@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend): mask_indptr = None max_extend_len = None elif forward_batch.forward_mode.is_target_verify(): - # TODO: Support sliding window in spec inference bs = len(forward_batch.req_pool_indices) qo_indptr = torch.arange( 0, @@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend): self.req_to_token.stride(0), ) + if self.sliding_window_size is not None and self.sliding_window_size > 0: + window_kv_indptr, window_kv_indices, window_kv_lens = ( + update_sliding_window_buffer( + self.window_kv_indptr, + self.req_to_token, + self.sliding_window_size, + forward_batch.seq_lens, + forward_batch.req_pool_indices, + bs, + self.device, + self.token_to_kv_pool_allocator, + ) + ) + custom_mask = spec_info.custom_mask seq_mask_len = self.num_draft_tokens * ( forward_batch.seq_lens + self.num_draft_tokens @@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend): forward_batch.req_pool_indices, bs, self.device, + self.token_to_kv_pool_allocator, ) qo_indptr = self.qo_indptr @@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend): ): window_kv_indices = self.cuda_graph_window_kv_indices window_num_kv_splits = self.cuda_graph_window_num_kv_splits - window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph( - self.window_kv_indptr, - window_kv_indices, - self.req_to_token, - self.sliding_window_size, - seq_lens[:bs], - req_pool_indices, - bs, + window_kv_indptr, window_kv_indices, _ = ( + update_sliding_window_buffer_cuda_graph( + self.window_kv_indptr, + window_kv_indices, + self.req_to_token, + self.sliding_window_size, + seq_lens[:bs], + req_pool_indices, + bs, + self.token_to_kv_pool_allocator, + ) ) else: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices @@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend): self.req_to_token.stride(0), ) + if self.sliding_window_size is not None and self.sliding_window_size > 0: + window_kv_indices = self.cuda_graph_window_kv_indices + window_num_kv_splits = self.cuda_graph_window_num_kv_splits + window_kv_indptr, window_kv_indices, _ = ( + update_sliding_window_buffer_cuda_graph( + self.window_kv_indptr, + window_kv_indices, + self.req_to_token, + self.sliding_window_size, + seq_lens, + req_pool_indices, + bs, + self.token_to_kv_pool_allocator, + ) + ) + custom_mask = self.cuda_graph_custom_mask custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) @@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend): ): window_num_kv_splits = self.cuda_graph_window_num_kv_splits window_kv_indices = self.cuda_graph_window_kv_indices - _, window_kv_lens = update_sliding_window_buffer_cuda_graph( + _, _, window_kv_lens = 
update_sliding_window_buffer_cuda_graph( self.window_kv_indptr, window_kv_indices, self.req_to_token, @@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend): seq_lens[:bs], req_pool_indices[:bs], bs, + self.token_to_kv_pool_allocator, ) self.get_num_kv_splits( window_num_kv_splits[:num_token], window_kv_lens[:bs] @@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend): kv_indices, self.req_to_token.stride(0), ) + if self.sliding_window_size is not None and self.sliding_window_size > 0: + window_num_kv_splits = self.cuda_graph_window_num_kv_splits + window_kv_indices = self.cuda_graph_window_kv_indices + _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph( + self.window_kv_indptr, + window_kv_indices, + self.req_to_token, + self.sliding_window_size, + seq_lens, + req_pool_indices, + bs, + self.token_to_kv_pool_allocator, + ) custom_mask = self.cuda_graph_custom_mask custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) @@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, + sk=None, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend): self.forward_metadata.max_extend_len, layer.scaling, layer.logit_cap, - sliding_window_size, + sliding_window_size=sliding_window_size, + sk=sk, ) return o @@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, + sk=None, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. 
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend): self.max_kv_splits, layer.scaling, layer.logit_cap, + sk=sk, ) return o @@ -932,10 +985,11 @@ def update_sliding_window_buffer( req_pool_indices, bs, device, + token_to_kv_pool_allocator=None, ): window_kv_lens = torch.minimum( seq_lens, - torch.tensor(sliding_window_size + 1), + torch.tensor(sliding_window_size), ) window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0) window_kv_indptr = window_kv_indptr[: bs + 1] @@ -952,6 +1006,14 @@ def update_sliding_window_buffer( window_kv_indices, req_to_token.stride(0), ) + # full to swa index mapping + if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"): + kv_last_index = window_kv_indptr[-1] + window_kv_indices[:kv_last_index] = ( + token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + window_kv_indices[:kv_last_index] + ) + ) return window_kv_indptr, window_kv_indices, window_kv_lens @@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph( seq_lens, req_pool_indices, bs, + token_to_kv_pool_allocator=None, ): window_kv_lens = torch.minimum( seq_lens, - torch.tensor(sliding_window_size + 1), + torch.tensor(sliding_window_size), ) window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0) window_kv_indptr = window_kv_indptr[: bs + 1] @@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph( window_kv_indices, req_to_token.stride(0), ) - return window_kv_indptr, window_kv_lens + # full to swa index mapping + if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"): + kv_last_index = window_kv_indptr[-1] + window_kv_indices[:kv_last_index] = ( + token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + window_kv_indices[:kv_last_index] + ) + ) + return window_kv_indptr, window_kv_indices, window_kv_lens diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index b334d851f..5e345586e 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -495,6 +495,7 @@ def _fwd_kernel_stage2( O, kv_indptr, num_kv_splits, + sk_ptr, stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -504,6 +505,7 @@ def _fwd_kernel_stage2( MIN_BLOCK_KV: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, + HAS_SK: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -545,6 +547,10 @@ def _fwd_kernel_stage2( e_sum = e_sum * old_scale + exp_logic e_max = n_e_max + if HAS_SK: + cur_sk = tl.load(sk_ptr + cur_head) + e_sum += tl.exp(cur_sk - e_max) + tl.store( O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / e_sum, @@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd( kv_indptr, num_kv_splits, max_kv_splits, + sk=None, ): batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] BLOCK_DV = triton.next_power_of_2(Lv) MAX_KV_SPLITS = max_kv_splits + HAS_SK = sk is not None extra_kargs = {} if _is_hip: @@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd( o, kv_indptr, num_kv_splits, + sk, logits.stride(0), logits.stride(1), logits.stride(2), @@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd( MIN_BLOCK_KV=_MIN_BLOCK_KV, BLOCK_DV=BLOCK_DV, Lv=Lv, + HAS_SK=HAS_SK, num_warps=4, num_stages=2, **extra_kargs, @@ -609,6 +619,7 @@ def decode_attention_fwd_normal( max_kv_splits, sm_scale, logit_cap=0.0, + sk=None, ): _decode_att_m_fwd( q, @@ -632,6 +643,7 @@ def decode_attention_fwd_normal( kv_indptr, 
num_kv_splits, max_kv_splits, + sk, ) @@ -648,6 +660,7 @@ def decode_attention_fwd_grouped( max_kv_splits, sm_scale, logit_cap=0.0, + sk=None, ): _decode_grouped_att_m_fwd( q, @@ -671,6 +684,7 @@ def decode_attention_fwd_grouped( kv_indptr, num_kv_splits, max_kv_splits, + sk, ) @@ -687,6 +701,7 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=0.0, + sk=None, ): assert max_kv_splits == attn_logits.shape[2] assert q.shape[0] <= kv_indptr.shape[0] - 1 @@ -709,6 +724,7 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=logit_cap, + sk=sk, ) else: # GQA/MQA/MLA @@ -725,4 +741,5 @@ def decode_attention_fwd( max_kv_splits, sm_scale, logit_cap=logit_cap, + sk=sk, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 67767df9b..e1b707f39 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -51,6 +51,7 @@ def _fwd_kernel( kv_indices, mask_ptr, mask_indptr, + sk_ptr, sm_scale, kv_group_num, stride_qbs, @@ -78,6 +79,7 @@ def _fwd_kernel( IS_CAUSAL: tl.constexpr, SKIP_PREFIX_CUSTOM_MASK: tl.constexpr, STORE_TRANSPOSE: tl.constexpr, + HAS_SK: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -178,13 +180,17 @@ def _fwd_kernel( final_mask &= custom_mask if SLIDING_WINDOW_SIZE > 0: # Add mask where q_id <= kv_id + sliding_window_size - window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= ( - start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE - ) + # q_id = prefix_len + cur_m, kv_id = cur_n + window_mask = ( + cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None] + ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE) final_mask &= window_mask qk = tl.where(final_mask, qk, float("-inf")) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) + row_max = tl.max(qk, 1) + row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max) + n_e_max = tl.maximum(row_max_fixed, e_max) + re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) deno = deno * re_scale + tl.sum(p, 1) @@ -242,6 +248,7 @@ def _fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) + final_mask = mask_m[:, None] & mask_n[None, :] if USE_CUSTOM_MASK: custom_mask = tl.load( mask_ptr @@ -254,18 +261,30 @@ def _fwd_kernel( other=0, ) custom_mask &= mask_m[:, None] & mask_n[None, :] - qk = tl.where(custom_mask, qk, float("-inf")) + final_mask &= custom_mask elif IS_CAUSAL: mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( start_n + offs_n[None, :] ) mask_causual &= mask_m[:, None] & mask_n[None, :] - qk = tl.where(mask_causual, qk, float("-inf")) + final_mask &= mask_causual else: mask_non_causal = mask_m[:, None] & mask_n[None, :] - qk = tl.where(mask_non_causal, qk, float("-inf")) + final_mask &= mask_non_causal + + if SLIDING_WINDOW_SIZE > 0: + # Add mask where q_id <= kv_id + sliding_window_size + window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= ( + start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE + ) + final_mask &= window_mask + + qk = tl.where(final_mask, qk, float("-inf")) + + row_max = tl.max(qk, 1) + row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max) + n_e_max = tl.maximum(row_max_fixed, e_max) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) deno = deno * re_scale + tl.sum(p, 1) @@ -283,6 +302,10 @@ def _fwd_kernel( e_max = n_e_max + if HAS_SK: + cur_sk 
= tl.load(sk_ptr + cur_head) + deno += tl.exp(cur_sk - e_max) + offs_o = ( (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs @@ -321,6 +344,7 @@ def extend_attention_fwd( logit_cap=0.0, skip_prefix_custom_mask=True, sliding_window_size=-1, + sk=None, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors @@ -386,6 +410,8 @@ def extend_attention_fwd( # Skip custom mask for prefix part SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask + HAS_SK = sk is not None + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) num_stages = 1 @@ -405,6 +431,7 @@ def extend_attention_fwd( kv_indices, custom_mask, mask_indptr, + sk, sm_scale, kv_group_num, q_extend.stride(0), @@ -431,6 +458,7 @@ def extend_attention_fwd( USE_CUSTOM_MASK=USE_CUSTOM_MASK, IS_CAUSAL=is_causal, SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK, + HAS_SK=HAS_SK, STORE_TRANSPOSE=_is_hip, num_warps=num_warps, num_stages=num_stages, diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 9e765ebf9..782699749 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase): else self.weight_loader ), ) - if not reduce_results and (bias and not skip_bias_add): - raise ValueError( - "When not reduce the results, adding bias to the " - "results can lead to incorrect results" - ) if bias: self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 74558fd9b..56ffe371b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module): no_combine: bool = False, routed_scaling_factor: Optional[float] = None, enable_flashinfer_cutlass_moe: Optional[bool] = False, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, + use_weight_loader_fused: bool = False, + with_bias=False, ): super().__init__() @@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module): self.expert_map_cpu = None self.expert_map_gpu = None + # For activation + self.activation_alpha = activation_alpha + self.swiglu_limit = swiglu_limit + if enable_flashinfer_cutlass_moe and quant_config is None: logger.warning("Disable flashinfer MoE when quantization config is None.") enable_flashinfer_cutlass_moe = False @@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module): if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( - self.use_triton_kernels + self.use_triton_kernels, with_bias=with_bias ) else: self.quant_method = quant_config.get_quant_method(self, prefix) @@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module): intermediate_size=self.intermediate_size_per_partition, intermediate_size_per_partition=self.intermediate_size_per_partition, params_dtype=params_dtype, - weight_loader=self.weight_loader, + weight_loader=( + self.weight_loader + if not use_weight_loader_fused + else self.weight_loader_fused + ), + with_bias=with_bias, ) def _load_per_tensor_weight_scale( @@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module): shard_id: str, loaded_weight: torch.Tensor, tp_rank: int, + is_bias: bool = False, ): # Load grouped weight scales for group quantization # or model weights @@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, 
tp_rank=tp_rank, + is_bias=is_bias, ) - elif shard_id in ("w1", "w3"): + elif shard_id in ("w1", "w3", "w13"): self._load_w13( shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank, + is_bias=is_bias, ) def _load_per_channel_weight_scale( @@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module): shard_id: str, loaded_weight: torch.Tensor, tp_rank: int, + is_bias: bool = False, ): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + assert shard_id in {"w1", "w3", "w13"} + + if is_bias: + # if this weight is a bias, the last dimension must be the sharded dimension + shard_dim = -1 + + if shard_id in {"w1", "w3"}: + # non-fused version + shard_size = expert_data.shape[shard_dim] // 2 + elif shard_id in {"w13"}: + # fused version + shard_size = expert_data.shape[shard_dim] + else: + raise NotImplementedError # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. # w3, up_proj: Load into second logical weight of w13. # trtllm cutlass kernel assumes differently - assert shard_id in ("w1", "w3") switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False) if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"): start = shard_size @@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module): ) else: if not self.use_presharded_weights: - if self.use_triton_kernels: + if not is_bias and self.use_triton_kernels: + # do not transpose for bias loaded_weight = loaded_weight.transpose(-2, -1) loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size @@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module): shard_id: str, loaded_weight: torch.Tensor, tp_rank: int, + is_bias: bool = False, ): """Load w2 weights for down projection. @@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. 
- shard_size = expert_data.shape[shard_dim] + if is_bias: + # this expert_data is a bias, not weight, + # for w2_bias in TP, it does not need to be sharded + shard_size = expert_data.shape[-1] + else: + # this parameter is a weight matrix + # for w2 in TP, it shards the input_features, i.e., shard_dim=2 + shard_size = expert_data.shape[shard_dim] if _is_cpu: expert_data, loaded_weight = narrow_padded_param_and_loaded_weight( @@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module): not self.use_presharded_weights, ) else: - if not self.use_presharded_weights: + if not is_bias and not self.use_presharded_weights: if self.use_triton_kernels: loaded_weight = loaded_weight.transpose(-2, -1) if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]: @@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module): ) return + def weight_loader_fused( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + ) -> None: + tp_rank = self.moe_tp_rank + + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO: check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + loaded_weight = ( + loaded_weight.t().contiguous() + if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsWNA16MoEMethod" + ) + else loaded_weight + ) + + if shard_id not in ("w13", "w2"): + raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.") + + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. + SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2} + SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1} + + expert_data = param.data + is_bias = expert_data.dim() == 2 + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is + is_transposed = getattr(param, "is_transposed", False) + + if self.use_triton_kernels: + is_transposed = True + shard_dim = ( + SHARD_ID_TO_SHARDED_DIM[shard_id] + if not is_transposed + else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id] + ) + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + is_bias=is_bias, + ) + return + else: + logging.warning( + f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded." + ) + def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): assert self.quant_method is not None @@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module): # Matrix multiply. 
with use_symmetric_memory(get_tp_group()) as sm: + kwargs = {} + if self.activation_alpha is not None: + kwargs["activation_alpha"] = self.activation_alpha + if self.swiglu_limit is not None: + kwargs["swiglu_limit"] = self.swiglu_limit + final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, @@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module): == "ModelOptNvFp4FusedMoEMethod" else {} ), + **kwargs, ) sm.tag(final_hidden_states) @@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module): ] ] + @classmethod + def make_expert_params_mapping_fused( + cls, + ckpt_gate_up_proj_name: str, + ckpt_down_proj_name: str, + ckpt_gate_up_proj_bias_name: str, + ckpt_down_proj_bias_name: str, + ): + return [ + ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"), + ( + "experts.w13_weight_bias", + f"experts.{ckpt_gate_up_proj_bias_name}", + "w13", + ), + ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"), + ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), + ] + @classmethod def make_expert_input_scale_params_mapping( cls, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py index eed33c5e8..36466661d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional import torch from sgl_kernel import gelu_and_mul, silu_and_mul -from triton_kernels.matmul_ogs import matmul_ogs +from triton_kernels.matmul_ogs import ( + FlexCtx, + FnSpecs, + FusedActivation, + PrecisionConfig, + matmul_ogs, +) +from triton_kernels.numerics import InFlexData from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx - -from sglang.srt.utils import direct_register_custom_op +from triton_kernels.swiglu import swiglu_fn if TYPE_CHECKING: from sglang.srt.layers.moe.topk import TopKOutput +def quantize(w, dtype, dev, **opt): + if dtype == "bf16": + return w.to(torch.bfloat16), InFlexData() + elif dtype == "fp8": + wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2) + return ( + wq, + InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), + MicroscalingCtx(), + ) + else: + assert dtype == "mx4", f"{dtype=}" + swizzle_mx_scale = opt["swizzle_mx_scale"] + swizzle_axis = 2 if swizzle_mx_scale else None + w = w.to(torch.bfloat16) + w, mx_scales, weight_scale_shape = downcast_to_mxfp( + w, torch.uint8, axis=1, swizzle_axis=swizzle_axis + ) + return ( + w, + InFlexData(), + MicroscalingCtx( + weight_scale=mx_scales, + swizzle_mx=swizzle_mx_scale, + actual_weight_scale_shape=weight_scale_shape, + ), + ) + + def triton_kernel_moe_forward( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -146,3 +181,143 @@ def triton_kernel_fused_experts( ) return intermediate_cache3 + + +def triton_kernel_moe_with_bias_forward( + hidden_states: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor, + w2: torch.Tensor, + b2: torch.Tensor, + topk_output: TopKOutput, + inplace: bool = False, + activation: str = "silu", + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + activation_alpha: Optional[float] 
= None, + swiglu_limit: Optional[int] = None, +) -> torch.Tensor: + assert topk_output.format.is_triton_kernel() + routing_data, gather_idx, scatter_idx = topk_output + + return triton_kernel_fused_experts_with_bias( + hidden_states, + w1, + b1, + w2, + b2, + routing_data, + gather_idx, + scatter_idx, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, + ) + + +def triton_kernel_fused_experts_with_bias( + hidden_states: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor, + w2: torch.Tensor, + b2: torch.Tensor, + routing_data: RoutingData, + gather_indx: GatherIndx, + scatter_indx: ScatterIndx, + inplace: bool = False, + activation: str = "silu", + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[int] = None, +) -> torch.Tensor: + # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype) + assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported" + assert per_channel_quant == False, "per_channel_quant is not supported" + assert expert_map == None, "expert_map is not supported" + assert w1_scale == None, "w1_scale is not supported" + assert w2_scale == None, "w2_scale is not supported" + assert a1_scale == None, "a1_scale is not supported" + assert a2_scale == None, "a2_scale is not supported" + assert block_shape == None, "block_shape is not supported" + + # type check + assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16" + assert w1.dtype == torch.bfloat16, "w1 must be bfloat16" + assert w2.dtype == torch.bfloat16, "w2 must be bfloat16" + + # Shape check + assert hidden_states.ndim == 2, "hidden_states must be 2D" + assert ( + hidden_states.shape[-1] == w1.shape[-2] + ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}" + assert ( + w2.shape[-1] == w1.shape[1] + ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}" + + # feature check + assert inplace == False, "Inplace is not supported in new triton MoE kernel" + + E, _, _ = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + device = "cuda" + optg = dict() + w1, w1_flex = quantize(w1, "bf16", device, **optg) + w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex)) + + w2, w2_flex = quantize(w2, "bf16", device, **optg) + w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex)) + + act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (activation_alpha, swiglu_limit), + 2, + ) + + intermediate_cache = matmul_ogs( + hidden_states, + w1, + b1, + routing_data, + gather_indx=gather_indx, + precision_config=w1_pcg, + gammas=None, + fused_activation=act, + ) + + return matmul_ogs( + intermediate_cache, + w2, + b2, + routing_data, + scatter_indx=scatter_indx, + precision_config=w2_pcg, + gammas=routing_data.gate_scal, + ) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py 
b/python/sglang/srt/layers/quantization/fp8_utils.py
index 3ab8634ac..d7638ce18 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -4,6 +4,7 @@ import torch
 
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.layers.utils import is_sm100_supported
 
 try:
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.utils import (
     align,
+    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
     return output.to(dtype=input_2d.dtype).view(*output_shape)
 
 
+def dequant_mxfp4(
+    w_block: torch.Tensor,
+    w_scale: torch.Tensor,
+    out_dtype,
+) -> torch.Tensor:
+    """
+    :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+    :param w_scale: (batch, n, k), uint8
+    :return: (batch, n, k * 32), float32
+    """
+
+    assert w_block.dtype == torch.uint8
+    assert w_scale.dtype == torch.uint8
+
+    batch, n, k, pack_dim = w_block.shape
+    batch_, n_, k_ = w_scale.shape
+    assert pack_dim == 16
+    assert batch == batch_
+    assert n == n_
+    assert k == k_
+
+    out_raw = MXFP4QuantizeUtil.dequantize(
+        quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+    )
+    return out_raw.reshape(batch, n, k * 32)
+
+
 def input_to_float8(
     x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
diff --git a/python/sglang/srt/layers/quantization/mxfp4_tensor.py b/python/sglang/srt/layers/quantization/mxfp4_tensor.py
new file mode 100644
index 000000000..e7b9a8346
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/mxfp4_tensor.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
+class MXFP4QuantizeUtil:
+    E2M1_max = 6.0
+
+    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
+
+    @classmethod
+    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E2M1 is supported.
+        Args:
+            input (torch.Tensor): The input tensor to be quantized.
+            block_size (int | None): The block size for quantization; 32 is used when None.
+        """
+
+        def cast_fp4(x):
+            sign = torch.sign(x)
+            sign_bit = (2 - sign) // 2
+            ord_ = torch.sum(
+                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
+            )
+            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
+            return fp4_val
+
+        def fuse_uint4_to_uint8(x):
+            # If the last dimension is odd, pad with zeros
+            # If this behavior is not desired, please modify the code accordingly
+            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
+            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
+            new_data = (
+                right_side.clone() << 4
+            )  # Put odd indices (higher addresses) in high bits
+            new_data[
+                ..., : left_side.shape[-1]
+            ] += left_side  # Put even indices in low bits
+            return new_data
+
+        if block_size is None:
+            block_size = 32
+
+        original_shape = input.shape
+        original_dtype = input.dtype
+        input = input.view(-1, block_size)
+        # get scales
+        input_amax = input.abs().max(dim=-1, keepdim=True).values
+        descale = input_amax / cls.E2M1_max
+        min_value = torch.tensor(-127.0, device=descale.device)
+        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
+
+        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
+        input_q = cast_fp4(input)
+        input_q = fuse_uint4_to_uint8(input_q)
+        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
+        return cls(original_shape, original_dtype, input_q), e8m0_scale
+
+    @classmethod
+    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
+        """Dequantize MXFP4 packed tensor to a target dtype."""
+
+        def unfuse_uint8_to_uint4(x):
+            """Unfuse uint8 values back to uint4 values.
+            This is the inverse operation of fuse_uint4_to_uint8.
+            """
+            # Extract the lower 4 bits (even indices)
+            left_side = x & 0x0F
+
+            # Extract the upper 4 bits (odd indices)
+            right_side = (x >> 4) & 0x0F
+
+            # Create a new tensor with alternating values
+            shape = list(x.shape)
+            shape[-1] = shape[-1] * 2
+            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
+
+            # Fill in the values - even indices get low bits, odd indices get high bits
+            result[..., 0::2] = left_side  # Even indices from low bits
+            result[..., 1::2] = right_side  # Odd indices from high bits
+
+            return result
+
+        e8m0_scale = scale
+        block_size = block_sizes[-1]
+
+        # Unfuse the uint8 values back to uint4
+        x_unfused = unfuse_uint8_to_uint4(quantized_data)
+        # Extract sign and magnitude
+        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
+            torch.float32
+        )  # Extract sign bit and convert to +1/-1
+        magnitude = x_unfused & 0b0111  # Extract magnitude bits
+        magnitude = magnitude.to(torch.long)
+
+        # Create a tensor with the E2M1 values
+        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
+
+        # Use gather to index the values tensor properly
+        # We need to reshape magnitude to match the dimensions we want to gather along
+        original_shape = magnitude.shape
+        x_float = values[magnitude.reshape(-1)].reshape(original_shape)
+
+        # Apply sign and scale
+        x_float = sign.float() * x_float
+
+        # Reshape to apply block-wise scaling
+        x_float = x_float.reshape(-1, block_size)
+
+        # Apply the E8M0 scale
+        scale_factor = torch.exp2(e8m0_scale.float() - 127)
+        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting
+
+        # Apply scaling and reshape back to original shape
+        x_float = x_float * scale_factor
+
+        # Reshape back to the original shape
+        return x_float.reshape(original_shape).to(dtype)
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index 38b889695..8fc4a5be1 100644
---
a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def __init__(self, use_triton_kernels: bool = False): + def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False): super().__init__() self.use_triton_kernels = use_triton_kernels + self.with_bias = with_bias self.triton_kernel_moe_forward = None + self.triton_kernel_moe_with_bias_forward = None if torch.cuda.is_available() and has_triton_kernels: from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( triton_kernel_moe_forward as _tk_forward, ) + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_with_bias_forward as _tk_with_bias_forward, + ) self.triton_kernel_moe_forward = _tk_forward + self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward def create_weights( self, @@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) + if self.with_bias: + w13_weight_bias = torch.nn.Parameter( + torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_bias", w13_weight_bias) + set_weight_attrs(w13_weight_bias, extra_weight_attrs) + # down_proj (row parallel) w2_weight_n, w2_weight_k = ( hidden_size, @@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + if self.with_bias: + w2_weight_bias = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_bias", w2_weight_bias) + set_weight_attrs(w2_weight_bias, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _use_aiter: layer.w13_weight = torch.nn.Parameter( @@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: + kwargs = {} + if activation_alpha is not None: + kwargs["activation_alpha"] = activation_alpha + if swiglu_limit is not None: + kwargs["swiglu_limit"] = swiglu_limit return self.forward( x=x, @@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): inplace=inplace, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, + **kwargs, ) def forward_cuda( @@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: if self.use_triton_kernels: - return self.triton_kernel_moe_forward( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_output=topk_output, - ) + if self.with_bias: + return self.triton_kernel_moe_with_bias_forward( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + b1=layer.w13_weight_bias, + b2=layer.w2_weight_bias, + topk_output=topk_output, + activation=activation, + 
activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, + ) + else: + return self.triton_kernel_moe_forward( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_output=topk_output, + ) else: if _use_aiter: assert not no_combine, "unsupported" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 99ea56965..3452608b3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): is_hybrid = False if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): - assert isinstance(tree_cache, SWARadixCache) or isinstance( - tree_cache, SWAChunkCache + assert ( + tree_cache is None + or isinstance(tree_cache, SWARadixCache) + or isinstance(tree_cache, SWAChunkCache) ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator" is_hybrid = True diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py new file mode 100644 index 000000000..cf40c652b --- /dev/null +++ b/python/sglang/srt/models/gpt_oss.py @@ -0,0 +1,923 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +"""Inference-only GptOss model compatible with HuggingFace weights.""" + +import logging +from collections.abc import Iterable +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import ( + get_moe_tensor_parallel_rank, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + get_local_attention_dp_size, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import DeepEPMode +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4 +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, make_layers + + +class GptOssConfig(PretrainedConfig): + model_type = "gpt_oss" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +logger = logging.getLogger(__name__) + + +# Aligned with HF's implementation, using sliding window inclusive with the last token +# SGLang assumes exclusive +def get_attention_sliding_window_size(config): + return config.sliding_window - 1 + + +class GptOssSparseMoeBlock(nn.Module): + def __init__( + self, + layer_id: int, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.layer_id = layer_id + self.activation = config.hidden_act + self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702) + self.swiglu_limit = config.swiglu_limit + if self.tp_size > config.num_local_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_local_experts}." 
+ ) + + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, + ) + + experts_type = get_moe_impl_class() + extra_kwargs = {} + if experts_type.__name__ == "FusedMoE": + extra_kwargs = { + "enable_flashinfer_cutlass_moe": global_server_args_dict[ + "enable_flashinfer_cutlass_moe" + ], + "use_weight_loader_fused": True, # for moe gate_up_proj and down_proj and their bias loading + } + self.experts = experts_type( + num_experts=config.num_local_experts + + global_server_args_dict["ep_num_redundant_experts"], + top_k=config.num_experts_per_tok, + layer_id=layer_id, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + activation=self.activation, + activation_alpha=self.activation_alpha, + swiglu_limit=self.swiglu_limit, + with_bias=True, + prefix=add_prefix("experts", prefix), + **( + dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) + if global_server_args_dict["moe_a2a_backend"].is_deepep() + else {} + ), + **extra_kwargs, + ) + + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=True, + quant_config=None, + prefix=add_prefix("gate", prefix), + params_dtype=config.torch_dtype, + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None + ) -> torch.Tensor: + if not global_server_args_dict["moe_a2a_backend"].is_deepep(): + return self.forward_normal(hidden_states) + else: + raise Exception("forward_deepep branch not implemented yet") + + def get_moe_weights(self): + return [ + x.data + for name, x in self.experts.named_parameters() + if name not in ["correction_bias"] + ] + + def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.router(hidden_states) + + kwargs = {"hidden_states": hidden_states} + if self.topk is not None: + kwargs["topk_output"] = self.topk(hidden_states, router_logits) + else: + kwargs["router_logits"] = router_logits + final_hidden_states = self.experts(**kwargs) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + ans = final_hidden_states.view(num_tokens, hidden_dim) + return ans + + +class GptOssAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + attention_bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + sliding_window_size: int = -1, # if -1, normal attention, else, window attention. + layer_type: str = "", + params_dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.sliding_window_size = sliding_window_size + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.tp_rank = get_tensor_model_parallel_rank() + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + params_dtype=params_dtype, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.sinks = nn.Parameter( + torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=attention_bias, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + reduce_results=False, + params_dtype=params_dtype, + prefix=add_prefix("o_proj", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + assert layer_type in {"sliding_attention", "full_attention"} + use_sliding_window = layer_type == "sliding_attention" + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=add_prefix("attn", prefix), + sliding_window_size=(sliding_window_size if use_sliding_window else -1), + ) + self.layer_id = layer_id + + def forward_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + if hidden_states.shape[0] == 0: + return hidden_states, forward_batch, None + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + inner_state = q, k, v, forward_batch + return None, forward_batch, inner_state + + def forward_core(self, intermediate_state): + hidden_states, forward_batch, inner_state = intermediate_state + if inner_state is None: + return hidden_states + attn_output = self.attn(*inner_state, sk=self.sinks) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + s = self.forward_prepare( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + return self.forward_core(s) + + +class GptOssDecoderLayer(nn.Module): + def __init__( + self, + config: GptOssConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + sliding_window_size: int | None = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rms_norm_eps = config.rms_norm_eps + attention_bias = 
config.attention_bias + + if sliding_window_size is None: + self.sliding_window_size = get_attention_sliding_window_size(self.config) + else: + self.sliding_window_size = sliding_window_size + + self.self_attn = GptOssAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + sliding_window_size=self.sliding_window_size, + layer_type=config.layer_types[layer_id], + params_dtype=config.torch_dtype, + ) + + self.layer_id = layer_id + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + self.local_dp_size = get_local_attention_dp_size() + + # GptOss all layers are sparse and have no nextn now + self.is_layer_sparse = True + is_previous_layer_sparse = True + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: + self.mlp = GptOssSparseMoeBlock( + layer_id=self.layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + else: + raise NotImplementedError( + "Dense MLP is not implemented for GptOssDecoderLayer. " + "Please use GptOssSparseMoeBlock instead." + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + + hidden_states = self.mlp(hidden_states, forward_batch) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +class GptOssModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer_type: type[nn.Module] = GptOssDecoderLayer, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pp_group = get_pp_group() + + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], + prefix=add_prefix("embed_tokens", prefix), + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to GptOssDecoderLayer + decoder_layer_type = decoder_layer_type or GptOssDecoderLayer + self.layers, 
self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: decoder_layer_type( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) + + self.layers_to_capture = [] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + aux_hidden_states = [] + for i in range(self.start_layer, self.end_layer): + with get_global_expert_distribution_recorder().with_current_layer(i): + if i in self.layers_to_capture: + aux_hidden_states.append(hidden_states + residual) + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states + + +class GptOssForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.pp_group = get_pp_group() + self.config = config + self.quant_config = quant_config + self.model = GptOssModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + ) + else: + return hidden_states + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer + + def _get_default_weight_mapping(self): + """Generate default weight name mapping for GptOss safetensors.""" + weight_mapping = {} + + # Map router weights to gate + weight_mapping["embedding.weight"] 
= "model.embed_tokens.weight" + weight_mapping["unembedding.weight"] = "lm_head.weight" + weight_mapping["norm.scale"] = "model.norm.weight" + for layer_id in range(self.config.num_hidden_layers): + weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = ( + f"model.layers.{layer_id}.self_attn.q_proj.weight" + ) + weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = ( + f"model.layers.{layer_id}.self_attn.q_proj.bias" + ) + + weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = ( + f"model.layers.{layer_id}.self_attn.k_proj.weight" + ) + weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = ( + f"model.layers.{layer_id}.self_attn.k_proj.bias" + ) + + weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = ( + f"model.layers.{layer_id}.self_attn.v_proj.weight" + ) + weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = ( + f"model.layers.{layer_id}.self_attn.v_proj.bias" + ) + + weight_mapping[f"block.{layer_id}.attn.out.weight"] = ( + f"model.layers.{layer_id}.self_attn.o_proj.weight" + ) + weight_mapping[f"block.{layer_id}.attn.out.bias"] = ( + f"model.layers.{layer_id}.self_attn.o_proj.bias" + ) + weight_mapping[f"block.{layer_id}.attn.sinks"] = ( + f"model.layers.{layer_id}.self_attn.sinks" + ) + weight_mapping[f"block.{layer_id}.attn.norm.scale"] = ( + f"model.layers.{layer_id}.input_layernorm.weight" + ) + + weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = ( + f"model.layers.{layer_id}.mlp.router.weight" + ) + weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = ( + f"model.layers.{layer_id}.mlp.router.bias" + ) + weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = ( + f"model.layers.{layer_id}.post_attention_layernorm.weight" + ) + weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = ( + f"model.layers.{layer_id}.mlp.experts.gate_up_proj" + ) + weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = ( + f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias" + ) + weight_mapping[f"block.{layer_id}.mlp.down_proj"] = ( + f"model.layers.{layer_id}.mlp.experts.mlp2_weight" + ) + weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = ( + f"model.layers.{layer_id}.mlp.experts.mlp2_bias" + ) + + return weight_mapping + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + is_nextn: bool = False, + weight_name_mapping: dict = None, + ): + tp_rank = get_tensor_model_parallel_rank() + if is_nextn: + logging.warning( + "Loading weights for nextn is currently not supported in GptOssForCausalLM. 
" + ) + return + weights = _canonicalize_weights(self.config, weights) + weights = sorted(weights, key=lambda x: x[0]) # Sort by name for consistency + + new_weights = [] + for name, p in weights: + if "qkv.weight" in name: + q_proj, k_proj, v_proj = p.split( + [ + self.config.num_attention_heads * self.config.head_dim, + self.config.num_key_value_heads * self.config.head_dim, + self.config.num_key_value_heads * self.config.head_dim, + ], + dim=0, + ) + new_weights.append( + (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj) + ) + new_weights.append( + (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj) + ) + new_weights.append( + (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj) + ) + elif "qkv.bias" in name: + q_bias, k_bias, v_bias = p.split( + [ + self.config.num_attention_heads * self.config.head_dim, + self.config.num_key_value_heads * self.config.head_dim, + self.config.num_key_value_heads * self.config.head_dim, + ], + dim=0, + ) + new_weights.append( + (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias) + ) + new_weights.append( + (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias) + ) + new_weights.append( + (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias) + ) + else: + new_weights.append((name, p)) + weights = new_weights + + # Use provided weight name mapping if available, otherwise use default + if weight_name_mapping is None: + weight_name_mapping = self._get_default_weight_mapping() + else: + # Merge with default mapping + default_mapping = self._get_default_weight_mapping() + default_mapping.update(weight_name_mapping) + weight_name_mapping = default_mapping + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused( + ckpt_gate_up_proj_name="gate_up_proj", + ckpt_down_proj_name="down_proj", + ckpt_gate_up_proj_bias_name="gate_up_proj_bias", + ckpt_down_proj_bias_name="down_proj_bias", + ) + + params_dict = dict(self.named_parameters()) + params_checker = {k: False for k, v in params_dict.items()} + for name, loaded_weight in weights: + loaded_weight = _WeightCreator.maybe_materialize(loaded_weight) + + # Apply weight name mapping if provided + if weight_name_mapping and name in weight_name_mapping: + name = weight_name_mapping[name] + + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + params_checker[name] = True + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + if "bias" not in name: + loaded_weight = loaded_weight.transpose(-2, -1) + if "w2_weight_bias" in name and 
+                        loaded_weight = loaded_weight.zero_()
+
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                    )
+                    params_checker[name] = True
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        if "sinks" in name:
+                            start = tp_rank * param.numel()
+                            param.data.copy_(
+                                loaded_weight[start : start + param.numel()]
+                            )
+                        else:
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight)
+                        params_checker[name] = True
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        not_loaded_params = [k for k, v in params_checker.items() if not v]
+        if tp_rank == 0:
+            if len(not_loaded_params) > 0:
+                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
+            else:
+                logging.info("All parameters loaded successfully.")
+
+        self.routed_experts_weights_of_layer = {
+            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+            for layer_id in range(self.start_layer, self.end_layer)
+            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+        }
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_local_experts,
+            num_groups=None,
+        )
+
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
+
+def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
+    weights_out_dict = dict(weights_in)
+
+    for layer_id in range(config.num_hidden_layers):
+        for name_chunk in ["mlp1_weight", "mlp2_weight"]:
+            name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
+            w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
+            w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
+            if w_blocks is not None:
+                weights_out_dict[name_prefix] = _WeightCreator(
+                    partial(
+                        _dequant_mlp_weight,
+                        debug_name=name_prefix,
+                        w_blocks=w_blocks,
+                        w_scales=w_scales,
+                    )
+                )
+
+    return list(weights_out_dict.items())
+
+
+def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(f"Dequantize {debug_name} start")
+
+    original_device = w_blocks.device
+
+    w_blocks = w_blocks.cuda()
+    w_scales = w_scales.cuda()
+
+    w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
+    w_bf16 = w_bf16.transpose(-2, -1).contiguous()
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(
+            f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
+        )
+        )
+
+    return w_bf16.to(original_device)
+
+
+class _WeightCreator:
+    def __init__(self, fn):
+        self._fn = fn
+
+    @staticmethod
+    def maybe_materialize(obj):
+        if isinstance(obj, _WeightCreator):
+            output = obj._fn()
+            obj._fn = None
+            return output
+
+        return obj
+
+
+EntryClass = GptOssForCausalLM
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6c4a818ae..0d64571c1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -457,6 +457,10 @@ class ServerArgs:
                 raise ValueError(
                     "trtllm_mla backend does not support speculative decoding yet."
                 )
+        model_arch = self.get_hf_config().architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            self.attention_backend = "triton"
+            self.enable_triton_kernel_moe = True
 
         # Set page size
         if self.page_size is None: