diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py
index 41a833f3d..f23ccc0a8 100644
--- a/python/sglang/srt/models/hunyuan.py
+++ b/python/sglang/srt/models/hunyuan.py
@@ -206,6 +206,42 @@ class HunYuanSparseMoeBlock(nn.Module):
         return final_hidden_states.view(orig_shape)
 
 
+def get_head_dim(config):
+    if hasattr(config, "head_dim"):
+        return int(config.head_dim)
+    if hasattr(config, "attention_head_dim"):
+        return int(config.attention_head_dim)
+
+    # Some HunYuan models do not follow the self.hidden_size // self.total_num_heads
+    # rule, and a wrong value causes runtime errors, so raise if this field is missing.
+    raise ValueError("Missing head dim in config; please set head_dim in config.json")
+
+
+def check_head_dim(config):
+    # Some models may lack `head_dim` and use `attention_head_dim` instead.
+    # This attribute is also used by flashinfer_backend.py, so we check for
+    # consistency and raise an error if it's not met to avoid silent failures.
+    # Although we could adapt the HunYuan model to use `attention_head_dim`,
+    # flashinfer expects `head_dim`, so we enforce its presence for correctness.
+    calc_head_dim = config.hidden_size // config.num_attention_heads
+
+    if hasattr(config, "attention_head_dim"):
+        if calc_head_dim != config.attention_head_dim and not hasattr(
+            config, "head_dim"
+        ):
+            # In this case, flashinfer (and other components) may compute a wrong value.
+            raise ValueError(
+                f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}"
+                + f"\nPlease add head_dim: {config.attention_head_dim} to config.json to ensure correct inference."
+            )
+
+        if hasattr(config, "head_dim") and config.attention_head_dim != config.head_dim:
+            raise ValueError(
+                f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})"
+                + f"\nPlease change head_dim in config.json to {config.attention_head_dim} to ensure correct inference."
+            )
+
+
 class HunYuanAttention(nn.Module):
 
     def __init__(
@@ -240,9 +276,11 @@ class HunYuanAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(
-            config, "head_dim", self.hidden_size // self.total_num_heads
-        )
+        # Prioritize `head_dim` but fall back to `attention_head_dim` for HunYuan models.
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -493,7 +531,6 @@ class HunYuanModel(nn.Module):
         hidden_states = self.get_input_embeddings(input_ids)
         residual = None
 
-        cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -560,6 +597,11 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
 
+        self.hidden_size = config.hidden_size
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
         self.sampler = Sampler()
@@ -582,16 +624,14 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             self.config, "num_key_value_heads", self.config.num_attention_heads
         )
         num_key_value_groups = num_attention_heads // num_kv_heads
-        hidden_size = self.config.hidden_size
-        attention_head_dim = self.config.hidden_size // num_attention_heads
 
         qkv = qkv.reshape(
-            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
+            num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
        )
         q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
-        q = q.reshape(-1, hidden_size)
-        k = k.reshape(-1, hidden_size)
-        v = v.reshape(-1, hidden_size)
+        q = q.reshape(-1, self.hidden_size)
+        k = k.reshape(-1, self.hidden_size)
+        v = v.reshape(-1, self.hidden_size)
 
         return torch.concat((q, k, v))
         # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
@@ -768,4 +808,8 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         )
 
 
-EntryClass = HunYuanMoEV1ForCausalLM
+class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
+    pass
+
+
+EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]
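For illustration only (not part of the patch): a minimal sketch of how the head-dim resolution and consistency check added above behave. The config objects and numeric values are hypothetical, chosen so that hidden_size // num_attention_heads differs from attention_head_dim.

from types import SimpleNamespace

from sglang.srt.models.hunyuan import check_head_dim, get_head_dim

# Hypothetical config: hidden_size // num_attention_heads = 64, but the model
# declares attention_head_dim = 128 and no head_dim, so check_head_dim raises.
cfg_missing = SimpleNamespace(hidden_size=1024, num_attention_heads=16, attention_head_dim=128)
try:
    check_head_dim(cfg_missing)
except ValueError as err:
    print(err)  # asks to add head_dim to config.json

# Same config with head_dim set explicitly: the check passes and head_dim wins.
cfg_ok = SimpleNamespace(
    hidden_size=1024, num_attention_heads=16, attention_head_dim=128, head_dim=128
)
check_head_dim(cfg_ok)
print(get_head_dim(cfg_ok))  # 128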