Support multi-node DP attention (#2925)
Co-authored-by: dhou-xai <dhou@x.ai>
This commit is contained in:
@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user