Support multi-node DP attention (#2925)

Co-authored-by: dhou-xai <dhou@x.ai>
This commit is contained in:
Lianmin Zheng
2025-01-16 11:15:00 -08:00
committed by GitHub
parent 58f3f2b840
commit 8b6ce52e92
16 changed files with 287 additions and 137 deletions

View File

@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
if not forward_batch.forward_mode.is_idle():
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [