Support DP MLA (#1970)
This commit is contained in:
@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
|
||||
DECODE = auto()
|
||||
# Contains both EXTEND and DECODE.
|
||||
MIXED = auto()
|
||||
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence allocated.
|
||||
IDLE = auto()
|
||||
|
||||
def is_prefill(self):
    """Return True when this forward mode is PREFILL."""
    return ForwardMode.PREFILL == self
|
||||
@@ -69,6 +71,9 @@ class ForwardMode(IntEnum):
|
||||
def is_mixed(self):
    """Return True when this forward mode is MIXED (both EXTEND and DECODE)."""
    return ForwardMode.MIXED == self
|
||||
|
||||
def is_idle(self):
    """Return True when this forward mode is IDLE (no sequences assigned to this worker)."""
    return ForwardMode.IDLE == self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardBatch:
|
||||
@@ -128,6 +133,10 @@ class ForwardBatch:
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
|
||||
def compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
@@ -209,10 +218,22 @@ class ForwardBatch:
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
global_num_tokens=batch.global_num_tokens,
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
)
|
||||
|
||||
if ret.global_num_tokens is not None:
|
||||
max_len = max(ret.global_num_tokens)
|
||||
ret.gathered_buffer = torch.zeros(
|
||||
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
||||
dtype=model_runner.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if ret.forward_mode.is_idle():
|
||||
return ret
|
||||
|
||||
# Init position information
|
||||
if not ret.forward_mode.is_decode():
|
||||
ret.positions = torch.concat(
|
||||
|
||||
@@ -141,6 +141,7 @@ class ModelRunner:
|
||||
"torchao_config": server_args.torchao_config,
|
||||
"disable_penalizer": server_args.disable_penalizer,
|
||||
"disable_nan_detection": server_args.disable_nan_detection,
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -592,11 +593,18 @@ class ModelRunner:
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
    """Run a forward pass for an IDLE batch.

    Used with data-parallel attention: a worker with no sequences
    allocated this step still participates in the forward call.
    """
    input_ids = forward_batch.input_ids
    positions = forward_batch.positions
    return self.model.forward(input_ids, positions, forward_batch)
|
||||
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
    """Dispatch a forward pass according to the batch's forward mode.

    Args:
        forward_batch: The batch to run; ``forward_batch.forward_mode``
            selects decode, extend, or idle execution.

    Returns:
        The output of the mode-specific forward method.

    Raises:
        ValueError: If the forward mode is none of decode, extend, or idle.
    """
    if forward_batch.forward_mode.is_decode():
        return self.forward_decode(forward_batch)
    elif forward_batch.forward_mode.is_extend():
        return self.forward_extend(forward_batch)
    elif forward_batch.forward_mode.is_idle():
        # IDLE: data-parallel attention worker with no sequences this step.
        return self.forward_idle(forward_batch)
    else:
        # Fixed typo in the error message: "Invaid" -> "Invalid".
        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user