Support DP MLA (#1970)
This commit is contained in:
@@ -141,6 +141,7 @@ class ModelRunner:
|
||||
"torchao_config": server_args.torchao_config,
|
||||
"disable_penalizer": server_args.disable_penalizer,
|
||||
"disable_nan_detection": server_args.disable_nan_detection,
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -592,11 +593,18 @@ class ModelRunner:
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
return self.forward_extend(forward_batch)
|
||||
elif forward_batch.forward_mode.is_idle():
|
||||
return self.forward_idle(forward_batch)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user