diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index e4dc1b878..bdd708cd0 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -175,7 +175,6 @@ Please consult the documentation below to learn more about the parameters you ma * `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph. * `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend. * `disable_custom_all_reduce`: Disable usage of custom all reduce kernel. -* `disable_mla`: Disable [Multi-Head Latent Attention](https://arxiv.org/html/2405.04434v5) for Deepseek model. * `disable_overlap_schedule`: Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). * `enable_nan_detection`: Turning this on makes the sampler print a warning if the logits contain `NaN`. * `enable_p2p_check`: Turns off the default of allowing always p2p check when accessing GPU. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 8fc71e03c..91d6d78fa 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -100,7 +100,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in Multi-head Latent Attention for DeepSeek Series Models

-**Usage**: MLA optimization is enabled by default. To disable MLA usage, use `--disable-mla`. To disable chunked prefix cache feature for mla, use `disable-chunked-prefix-cache`. +**Usage**: MLA optimization is enabled by default. To disable the chunked prefix cache feature for MLA, use `--disable-chunked-prefix-cache`. **Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details. diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 26e54546d..cee8ae4c8 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -299,9 +299,7 @@ class FlashAttentionBackend(AttentionBackend): self.kv_cache_dtype = model_runner.kv_cache_dtype self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype self.page_size = model_runner.page_size - self.use_mla = ( - model_runner.model_config.attention_arch == AttentionArch.MLA - ) and (not global_server_args_dict["disable_mla"]) + self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA self.skip_prefill = skip_prefill self.topk = topk diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6bfd69307..7d3afa824 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -67,7 +67,6 @@ global_server_args_dict = { "attention_backend": ServerArgs.attention_backend, "sampling_backend": ServerArgs.sampling_backend, "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, - "disable_mla": ServerArgs.disable_mla, "torchao_config": ServerArgs.torchao_config, "enable_nan_detection": ServerArgs.enable_nan_detection, 
"enable_dp_attention": ServerArgs.enable_dp_attention, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ed899f080..df443c599 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -127,10 +127,7 @@ class ModelRunner: self.page_size = server_args.page_size self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.use_mla_backend = ( - self.model_config.attention_arch == AttentionArch.MLA - and not server_args.disable_mla - ) + self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size # Model-specific adjustment @@ -150,7 +147,6 @@ class ModelRunner: "attention_backend": server_args.attention_backend, "sampling_backend": server_args.sampling_backend, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, - "disable_mla": server_args.disable_mla, "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index a1739ec78..645b13d74 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -262,79 +262,75 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): ) weight_loader(param, loaded_weight) - if not global_server_args_dict["disable_mla"]: - self_attn = self.model.decoder.self_attn - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T + 
self_attn = self.model.decoder.self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it weight_block_size = self.quant_config.weight_block_size if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = 
self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - assert hasattr(self_attn.kv_b_proj, "weight_scale") - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant(weight, weight_scale, weight_block_size).to( torch.bfloat16 ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + else: + # channel-wise int8 need it + assert hasattr(self_attn.kv_b_proj, "weight_scale") + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None: + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 
EntryClass = [DeepseekV3ForCausalLMNextN] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index abb9aa4bb..22831a310 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -386,179 +386,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 -class DeepseekV2Attention(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int, - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - layer_id=None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - self.layer_id = layer_id - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.dp_size = get_attention_dp_size() - attn_tp_rank = get_attention_tp_rank() - attn_tp_size = get_attention_tp_size() - - self.num_heads = num_heads - assert num_heads % attn_tp_size == 0 - self.num_local_heads = num_heads // attn_tp_size - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_a_proj", prefix), - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - 
prefix=add_prefix("q_b_proj", prefix), - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_proj", prefix), - tp_rank=attn_tp_rank, - tp_size=attn_tp_size, - ) - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_b_proj", prefix), - ) - # O projection. - self.o_proj = RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=add_prefix("o_proj", prefix), - reduce_results=reduce_results, - tp_rank=attn_tp_rank, - tp_size=attn_tp_size, - ) - rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope_wrapper( - qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False, - device=global_server_args_dict["device"], - ) - - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - # TODO, support head_size 192 - self.attn = RadixAttention( - self.num_local_heads, - 256, - self.scaling, - num_kv_heads=self.num_local_heads, - layer_id=layer_id, - quant_config=quant_config, - prefix=add_prefix("attn", prefix), - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - if hidden_states.shape[0] == 0: - 
assert ( - not self.o_proj.reduce_results - ), "short-circuiting allreduce will lead to hangs" - return hidden_states - - if self.q_lora_rank is not None: - q = self.q_a_proj(hidden_states)[0] - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) - else: - q = self.q_proj(hidden_states)[0].view( - -1, self.num_local_heads, self.qk_head_dim - ) - _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) - kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe - k = torch.empty_like(q) - k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe - q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view( - -1, self.num_local_heads * 256 - ) - k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view( - -1, self.num_local_heads * 256 - ) - v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view( - -1, self.num_local_heads * 256 - ) - attn_output = self.attn(q, k, v, forward_batch) - attn_output = attn_output.view(-1, self.num_local_heads, 256)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) - output, _ = self.o_proj(attn_output) - return output - - class DeepseekV2AttentionMLA(nn.Module): def __init__( @@ -1200,47 +1027,25 @@ class DeepseekV2DecoderLayer(nn.Module): self.dp_size = get_attention_dp_size() self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - - if not 
global_server_args_dict["disable_mla"]: - self.self_attn = DeepseekV2AttentionMLA( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=( - config.q_lora_rank if hasattr(config, "q_lora_rank") else None - ), - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - layer_id=layer_id, - reduce_results=False, - prefix=add_prefix("self_attn", prefix), - ) - else: - self.self_attn = DeepseekV2Attention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=( - config.q_lora_rank if hasattr(config, "q_lora_rank") else None - ), - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - layer_id=layer_id, - reduce_results=False, - prefix=add_prefix("self_attn", prefix), - ) + self.self_attn = DeepseekV2AttentionMLA( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + layer_id=layer_id, + reduce_results=False, + prefix=add_prefix("self_attn", prefix), + ) if is_nextn or is_sparse_layer(layer_id): self.mlp = DeepseekV2MoE( @@ -1551,87 +1356,85 @@ class DeepseekV2ForCausalLM(nn.Module): def 
post_load_weights(self): # Perform post-processing after loading weights - - if not global_server_args_dict["disable_mla"]: - for layer_id in range(self.config.num_hidden_layers): - self_attn = self.model.layers[layer_id].self_attn - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - if w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - if hasattr(self.quant_config, "weight_block_size"): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. 
+ # This may affect the accuracy of fp8 model. + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + if hasattr(self.quant_config, "weight_block_size"): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, ) - self_attn.w_scale = scale - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - self_attn.w_scale = scale - - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + else: weight = w weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + self_attn.w_scale = scale + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + + if w.dtype == torch.int8: 
+ if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if _is_hip: + self_attn.w_scale *= 2.0 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index eae2bf007..c7053f78f 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -93,158 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() -class MiniCPM3Attention(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int, - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - layer_id=None, - prefix: str = "", - ) -> None: - super().__init__() - self.layer_id = layer_id - self.hidden_size = hidden_size - self.qk_nope_head_dim = 
qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_a_proj", prefix), - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_b_proj", prefix), - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_proj", prefix), - ) - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_a_proj_with_mqa", prefix), - ) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_b_proj", prefix), - ) - # O projection. 
- self.o_proj = RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=add_prefix("o_proj", prefix), - ) - self.rotary_emb = get_rope( - qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - - # TODO support head_size 96 - self.attn = RadixAttention( - self.num_local_heads, - 128, - self.scaling, - num_kv_heads=self.num_local_heads, - layer_id=layer_id, - quant_config=quant_config, - prefix=add_prefix("attn", prefix), - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - if self.q_lora_rank is not None: - q = self.q_a_proj(hidden_states)[0] - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) - else: - q = self.q_proj(hidden_states)[0].view( - -1, self.num_local_heads, self.qk_head_dim - ) - _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) - kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] - original_shapes = [q_pe.shape, k_pe.shape] - q_pe, k_pe = self.rotary_emb( - positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1) - ) - q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1]) - q[..., self.qk_nope_head_dim :] = q_pe - k = torch.empty_like(q) - k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe - q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], 
value=0).view( - -1, self.num_local_heads * 128 - ) - k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view( - -1, self.num_local_heads * 128 - ) - v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view( - -1, self.num_local_heads * 128 - ) - attn_output = self.attn(q, k, v, forward_batch) - attn_output = attn_output.view(-1, self.num_local_heads, 128)[ - ..., : self.v_head_dim - ].reshape(-1, self.num_local_heads * self.v_head_dim) - output, _ = self.o_proj(attn_output) - return output - - class MiniCPM3AttentionMLA(nn.Module): def __init__( @@ -434,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - if not global_server_args_dict["disable_mla"]: - self.self_attn = MiniCPM3AttentionMLA( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=self.hidden_size // config.num_attention_heads, - q_lora_rank=( - config.q_lora_rank if hasattr(config, "q_lora_rank") else None - ), - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - layer_id=layer_id, - prefix=add_prefix("self_attn", prefix), - ) - else: - self.self_attn = MiniCPM3Attention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=self.hidden_size // config.num_attention_heads, - q_lora_rank=( - config.q_lora_rank if hasattr(config, "q_lora_rank") else None - ), - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - 
quant_config=quant_config, - layer_id=layer_id, - prefix=add_prefix("self_attn", prefix), - ) + self.self_attn = MiniCPM3AttentionMLA( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=self.hidden_size // config.num_attention_heads, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + layer_id=layer_id, + prefix=add_prefix("self_attn", prefix), + ) + self.mlp = MiniCPM3MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, @@ -674,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module): ) weight_loader(param, loaded_weight) - if not global_server_args_dict["disable_mla"]: - for layer_id in range(self.config.num_hidden_layers): - self_attn = self.model.layers[layer_id].self_attn - w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if hasattr(self_attn.kv_b_proj, "weight_scale"): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - del self_attn.kv_b_proj + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if hasattr(self_attn.kv_b_proj, "weight_scale"): + self_attn.w_scale = 
self_attn.kv_b_proj.weight_scale + del self_attn.kv_b_proj EntryClass = MiniCPM3ForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b45860aba..e172e66bb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -155,7 +155,6 @@ class ServerArgs: enable_nccl_nvls: bool = False disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False - disable_mla: bool = False enable_llama4_multimodal: Optional[bool] = None disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False @@ -975,11 +974,6 @@ class ServerArgs: action="store_true", help="Disable the custom all-reduce kernel and fall back to NCCL.", ) - parser.add_argument( - "--disable-mla", - action="store_true", - help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.", - ) parser.add_argument( "--enable-llama4-multimodal", default=ServerArgs.enable_llama4_multimodal,