perf : optimize memory for deepseek mtp (#2713)
### What this PR does / why we need it? delete the temp tensor to optimize memory for deepseek mtp for torchair case - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -102,6 +102,7 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
del inputs_embeds, previous_hidden_states
|
||||
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||
|
||||
hidden_states, residual = self.mtp_block(
|
||||
|
||||
Reference in New Issue
Block a user