Support DP MLA (#1970)

This commit is contained in:
Ke Bao
2024-11-16 17:01:43 +08:00
committed by GitHub
parent 2f2e07439c
commit 976bc302e5
12 changed files with 395 additions and 63 deletions

View File

@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
DECODE = auto()
# Contains both EXTEND and DECODE.
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
IDLE = auto()
def is_prefill(self):
    """Return True when this forward mode is PREFILL."""
    return ForwardMode.PREFILL == self
@@ -69,6 +71,9 @@ class ForwardMode(IntEnum):
def is_mixed(self):
    """Return True when this forward mode is MIXED (contains both EXTEND and DECODE)."""
    return ForwardMode.MIXED == self
def is_idle(self):
    """Return True when this forward mode is IDLE (no sequence to forward on this worker)."""
    return ForwardMode.IDLE == self
@dataclass
class ForwardBatch:
@@ -128,6 +133,10 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None
# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
@@ -209,10 +218,22 @@ class ForwardBatch:
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
global_num_tokens=batch.global_num_tokens,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)
if ret.global_num_tokens is not None:
max_len = max(ret.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)
if ret.forward_mode.is_idle():
return ret
# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(

View File

@@ -141,6 +141,7 @@ class ModelRunner:
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"disable_nan_detection": server_args.disable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
)
@@ -592,11 +593,18 @@ class ModelRunner:
get_embedding=True,
)
def forward_idle(self, forward_batch: ForwardBatch):
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")