[Feature] Comprehensive Hybrid Parallelism Support (#6389)

This commit is contained in:
Cheng Wan
2025-06-20 14:43:11 -07:00
committed by GitHub
parent 0998808009
commit e879d8b7a8
14 changed files with 3689 additions and 108 deletions

View File

@@ -1621,8 +1621,6 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
@@ -1706,7 +1704,6 @@ class DeepseekV2ForCausalLM(nn.Module):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {