Support DP MLA (#1970)

This commit is contained in:
Ke Bao
2024-11-16 17:01:43 +08:00
committed by GitHub
parent 2f2e07439c
commit 976bc302e5
12 changed files with 395 additions and 63 deletions

View File

@@ -141,6 +141,7 @@ class ModelRunner:
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"disable_nan_detection": server_args.disable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
)
@@ -592,11 +593,18 @@ class ModelRunner:
get_embedding=True,
)
def forward_idle(self, forward_batch: ForwardBatch):
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")