Support DP MLA (#1970)
This commit is contained in:
1
.github/workflows/pr-test.yml
vendored
1
.github/workflows/pr-test.yml
vendored
@@ -244,6 +244,7 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_mla.py
|
python3 test_mla.py
|
||||||
python3 test_mla_fp8.py
|
python3 test_mla_fp8.py
|
||||||
|
python3 test_dp_attention.py
|
||||||
|
|
||||||
- name: Evaluate data parallelism accuracy (DP=2)
|
- name: Evaluate data parallelism accuracy (DP=2)
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
|
|||||||
@@ -28,9 +28,13 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.decode_attention_fwd = decode_attention_fwd
|
self.decode_attention_fwd = decode_attention_fwd
|
||||||
self.extend_attention_fwd = extend_attention_fwd
|
self.extend_attention_fwd = extend_attention_fwd
|
||||||
self.num_head = (
|
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
if model_runner.server_args.enable_dp_attention:
|
||||||
)
|
self.num_head = model_runner.model_config.num_attention_heads
|
||||||
|
else:
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
|
)
|
||||||
|
|
||||||
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
|
||||||
self.reduce_dtype = torch.float32
|
self.reduce_dtype = torch.float32
|
||||||
|
|||||||
@@ -81,20 +81,34 @@ class DataParallelController:
|
|||||||
# Start data parallel workers
|
# Start data parallel workers
|
||||||
base_gpu_id = 0
|
base_gpu_id = 0
|
||||||
self.workers = []
|
self.workers = []
|
||||||
|
scheduler_pipe_readers = []
|
||||||
for dp_rank in range(server_args.dp_size):
|
for dp_rank in range(server_args.dp_size):
|
||||||
tmp_port_args = PortArgs.init_new(server_args)
|
tmp_port_args = PortArgs.init_new(server_args)
|
||||||
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
||||||
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
||||||
|
|
||||||
send_to = self.launch_tensor_parallel_group(
|
if server_args.enable_dp_attention:
|
||||||
server_args,
|
# Share workers for DP and TP
|
||||||
tmp_port_args,
|
send_to, reader = self.launch_tensor_parallel_process(
|
||||||
base_gpu_id,
|
server_args,
|
||||||
dp_rank,
|
tmp_port_args,
|
||||||
)
|
base_gpu_id,
|
||||||
|
dp_rank,
|
||||||
|
)
|
||||||
|
base_gpu_id += 1
|
||||||
|
scheduler_pipe_readers.append(reader)
|
||||||
|
else:
|
||||||
|
send_to = self.launch_tensor_parallel_group(
|
||||||
|
server_args,
|
||||||
|
tmp_port_args,
|
||||||
|
base_gpu_id,
|
||||||
|
dp_rank,
|
||||||
|
)
|
||||||
|
base_gpu_id += server_args.tp_size
|
||||||
self.workers.append(send_to)
|
self.workers.append(send_to)
|
||||||
base_gpu_id += server_args.tp_size
|
|
||||||
|
for reader in scheduler_pipe_readers:
|
||||||
|
reader.recv()
|
||||||
|
|
||||||
def launch_tensor_parallel_group(
|
def launch_tensor_parallel_group(
|
||||||
self,
|
self,
|
||||||
@@ -132,6 +146,27 @@ class DataParallelController:
|
|||||||
|
|
||||||
return send_to
|
return send_to
|
||||||
|
|
||||||
|
def launch_tensor_parallel_process(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
    base_gpu_id: int,
    dp_rank: int,
):
    """Start a single scheduler process and open a socket towards it.

    Used when DP attention shares workers between DP and TP: the child is
    given ``base_gpu_id`` as its GPU id and ``dp_rank`` as both its TP rank
    and its DP rank.

    Returns:
        A ``(send_to, reader)`` pair. ``send_to`` is the ZMQ PUSH socket
        obtained for ``port_args.scheduler_input_ipc_name``; ``reader`` is
        the receiving end of the one-way pipe whose sending end was handed
        to the child process.
    """
    pipe_reader, pipe_writer = mp.Pipe(duplex=False)

    # Positional layout expected by run_scheduler_process:
    # (server_args, port_args, gpu_id, tp_rank, dp_rank, pipe_writer)
    scheduler_args = (
        server_args,
        port_args,
        base_gpu_id,
        dp_rank,
        dp_rank,
        pipe_writer,
    )
    worker_proc = mp.Process(target=run_scheduler_process, args=scheduler_args)
    worker_proc.start()

    input_socket = get_zmq_socket(
        self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
    )
    return input_socket, pipe_reader
|
||||||
|
|
||||||
def round_robin_scheduler(self, req):
|
def round_robin_scheduler(self, req):
|
||||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ global_server_args_dict = {
|
|||||||
"disable_mla": ServerArgs.disable_mla,
|
"disable_mla": ServerArgs.disable_mla,
|
||||||
"torchao_config": ServerArgs.torchao_config,
|
"torchao_config": ServerArgs.torchao_config,
|
||||||
"disable_nan_detection": ServerArgs.disable_nan_detection,
|
"disable_nan_detection": ServerArgs.disable_nan_detection,
|
||||||
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -450,6 +451,9 @@ class ScheduleBatch:
|
|||||||
# The sum of all sequence lengths
|
# The sum of all sequence lengths
|
||||||
seq_lens_sum: int = None
|
seq_lens_sum: int = None
|
||||||
|
|
||||||
|
# For DP attention
|
||||||
|
global_num_tokens: Optional[List[int]] = None
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
@@ -858,6 +862,16 @@ class ScheduleBatch:
|
|||||||
# Reset the encoder cached status
|
# Reset the encoder cached status
|
||||||
self.encoder_cached = [True] * len(self.reqs)
|
self.encoder_cached = [True] * len(self.reqs)
|
||||||
|
|
||||||
|
def prepare_for_idle(self):
    """Switch this batch to IDLE mode with empty inputs.

    Sets ``forward_mode`` to ``ForwardMode.IDLE``, replaces ``input_ids``
    and ``seq_lens`` with zero-length int32 tensors on ``self.device``, and
    resets ``extend_num_tokens`` to 0.
    """

    def _empty_int32():
        # Zero-length placeholder moved to the batch device without blocking.
        return torch.empty(0, dtype=torch.int32).to(self.device, non_blocking=True)

    self.forward_mode = ForwardMode.IDLE
    self.input_ids = _empty_int32()
    self.seq_lens = _empty_int32()
    self.extend_num_tokens = 0
|
||||||
|
|
||||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
@@ -969,17 +983,18 @@ class ScheduleBatch:
|
|||||||
self.has_grammar = self.has_grammar or other.has_grammar
|
self.has_grammar = self.has_grammar or other.has_grammar
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
||||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||||
else:
|
else:
|
||||||
extend_seq_lens = self.extend_lens
|
extend_seq_lens = self.extend_lens
|
||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
|
||||||
if self.has_grammar:
|
if self.sampling_info is not None:
|
||||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
if self.has_grammar:
|
||||||
else:
|
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||||
self.sampling_info.grammars = None
|
else:
|
||||||
|
self.sampling_info.grammars = None
|
||||||
|
|
||||||
global bid
|
global bid
|
||||||
bid += 1
|
bid += 1
|
||||||
@@ -995,6 +1010,7 @@ class ScheduleBatch:
|
|||||||
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
top_logprobs_nums=self.top_logprobs_nums,
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
|
global_num_tokens=self.global_num_tokens,
|
||||||
extend_num_tokens=self.extend_num_tokens,
|
extend_num_tokens=self.extend_num_tokens,
|
||||||
extend_seq_lens=extend_seq_lens,
|
extend_seq_lens=extend_seq_lens,
|
||||||
extend_prefix_lens=extend_prefix_lens,
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
|
|||||||
return_logprob: bool
|
return_logprob: bool
|
||||||
top_logprobs_nums: Optional[List[int]]
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
|
||||||
|
# For DP attention
|
||||||
|
global_num_tokens: Optional[List[int]]
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
extend_num_tokens: Optional[int]
|
extend_num_tokens: Optional[int]
|
||||||
extend_seq_lens: Optional[List[int]]
|
extend_seq_lens: Optional[List[int]]
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class Scheduler:
|
|||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
|
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
||||||
)
|
)
|
||||||
@@ -347,6 +347,10 @@ class Scheduler:
|
|||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
|
|
||||||
|
if self.server_args.enable_dp_attention:
|
||||||
|
batch = self.prepare_dp_attn_batch(batch)
|
||||||
|
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
@@ -361,6 +365,8 @@ class Scheduler:
|
|||||||
self.update_running_batch()
|
self.update_running_batch()
|
||||||
if not self.running_batch:
|
if not self.running_batch:
|
||||||
break
|
break
|
||||||
|
if self.server_args.enable_dp_attention:
|
||||||
|
batch = self.prepare_dp_attn_batch(batch)
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
else:
|
else:
|
||||||
@@ -396,8 +402,48 @@ class Scheduler:
|
|||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    """Exchange per-rank token counts and attach them to the local batch.

    Every rank reports how many tokens it will forward (0 when it has no
    batch). The counts are all-gathered over the TP device group. If this
    rank has no batch but some other rank does, an idle batch is created so
    this rank still takes part in the forward pass.

    Returns:
        The batch to run (the given one or a fresh idle batch) with
        ``global_num_tokens`` set to the gathered list, or ``None`` when no
        rank reported any tokens and this rank had no batch.
    """
    # Token count contributed by this rank.
    if local_batch is None:
        num_tokens = 0
    else:
        num_tokens = (
            local_batch.batch_size()
            if local_batch.forward_mode.is_decode()
            else local_batch.extend_num_tokens
        )

    device = self.device
    local_count = torch.tensor(num_tokens, dtype=torch.int64, device=device)
    gathered_counts = torch.empty(self.tp_size, dtype=torch.int64, device=device)
    torch.distributed.all_gather_into_tensor(
        gathered_counts,
        local_count,
        group=self.tp_worker.get_tp_device_group(),
    )

    # Other workers have work: join them with an empty (idle) batch.
    if local_batch is None and gathered_counts.max().item() > 0:
        local_batch = self.get_idle_batch()

    if local_batch is None:
        return local_batch

    local_batch.global_num_tokens = gathered_counts.tolist()
    return local_batch
|
||||||
|
|
||||||
|
def get_idle_batch(self):
    """Build a request-less ``ScheduleBatch`` that is prepared for IDLE mode.

    The batch is created from an empty request list plus this scheduler's
    pools, tree cache and model config, then ``prepare_for_idle()`` is
    called on it before it is returned.
    """
    no_requests = []
    batch = ScheduleBatch.init_new(
        no_requests,
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
        self.model_config,
    )
    batch.prepare_for_idle()
    return batch
|
||||||
|
|
||||||
def recv_requests(self):
|
def recv_requests(self):
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||||
recv_reqs = []
|
recv_reqs = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -409,7 +455,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
|
|
||||||
if self.tp_size != 1:
|
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
|
||||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
@@ -812,6 +858,10 @@ class Scheduler:
|
|||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
|
elif batch.forward_mode.is_idle():
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
self.tp_worker.forward_batch_idle(model_worker_batch)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
logits_output = None
|
logits_output = None
|
||||||
if self.skip_tokenizer_init:
|
if self.skip_tokenizer_init:
|
||||||
@@ -830,6 +880,8 @@ class Scheduler:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||||
|
if batch.forward_mode.is_idle():
|
||||||
|
return
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
self.process_batch_result_decode(batch, result)
|
self.process_batch_result_decode(batch, result)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
|
|||||||
@@ -128,12 +128,19 @@ class TpModelWorker:
|
|||||||
def get_tp_cpu_group(self):
|
def get_tp_cpu_group(self):
|
||||||
return self.model_runner.tp_group.cpu_group
|
return self.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
|
def get_tp_device_group(self):
    """Return the ``device_group`` of the model runner's TP group."""
    tp_group = self.model_runner.tp_group
    return tp_group.device_group
|
||||||
|
|
||||||
def get_memory_pool(self):
|
def get_memory_pool(self):
|
||||||
return (
|
return (
|
||||||
self.model_runner.req_to_token_pool,
|
self.model_runner.req_to_token_pool,
|
||||||
self.model_runner.token_to_kv_pool,
|
self.model_runner.token_to_kv_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
    """Run the model runner's forward on a batch; the result is discarded.

    A ``ForwardBatch`` is built from ``model_worker_batch`` and passed to
    ``self.model_runner.forward``. Nothing is returned.
    """
    runner = self.model_runner
    idle_forward_batch = ForwardBatch.init_new(model_worker_batch, runner)
    runner.forward(idle_forward_batch)
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ class TpModelWorkerClient:
|
|||||||
def get_tp_cpu_group(self):
|
def get_tp_cpu_group(self):
|
||||||
return self.worker.get_tp_cpu_group()
|
return self.worker.get_tp_cpu_group()
|
||||||
|
|
||||||
|
def get_tp_device_group(self):
    """Delegate to the wrapped worker's ``get_tp_device_group``."""
    wrapped_worker = self.worker
    return wrapped_worker.get_tp_device_group()
|
||||||
|
|
||||||
def get_memory_pool(self):
|
def get_memory_pool(self):
|
||||||
return (
|
return (
|
||||||
self.worker.model_runner.req_to_token_pool,
|
self.worker.model_runner.req_to_token_pool,
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
|
|||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
# Contains both EXTEND and DECODE.
|
# Contains both EXTEND and DECODE.
|
||||||
MIXED = auto()
|
MIXED = auto()
|
||||||
|
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated to them.
|
||||||
|
IDLE = auto()
|
||||||
|
|
||||||
def is_prefill(self):
|
def is_prefill(self):
|
||||||
return self == ForwardMode.PREFILL
|
return self == ForwardMode.PREFILL
|
||||||
@@ -69,6 +71,9 @@ class ForwardMode(IntEnum):
|
|||||||
def is_mixed(self):
|
def is_mixed(self):
|
||||||
return self == ForwardMode.MIXED
|
return self == ForwardMode.MIXED
|
||||||
|
|
||||||
|
def is_idle(self):
    """Return True when this mode equals ``ForwardMode.IDLE``."""
    idle_mode = ForwardMode.IDLE
    return self == idle_mode
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardBatch:
|
class ForwardBatch:
|
||||||
@@ -128,6 +133,10 @@ class ForwardBatch:
|
|||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
|
# For DP attention
|
||||||
|
global_num_tokens: Optional[List[int]] = None
|
||||||
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def compute_mrope_positions(
|
def compute_mrope_positions(
|
||||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||||
):
|
):
|
||||||
@@ -209,10 +218,22 @@ class ForwardBatch:
|
|||||||
seq_lens_sum=batch.seq_lens_sum,
|
seq_lens_sum=batch.seq_lens_sum,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
|
global_num_tokens=batch.global_num_tokens,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ret.global_num_tokens is not None:
|
||||||
|
max_len = max(ret.global_num_tokens)
|
||||||
|
ret.gathered_buffer = torch.zeros(
|
||||||
|
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
||||||
|
dtype=model_runner.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret.forward_mode.is_idle():
|
||||||
|
return ret
|
||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if not ret.forward_mode.is_decode():
|
if not ret.forward_mode.is_decode():
|
||||||
ret.positions = torch.concat(
|
ret.positions = torch.concat(
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class ModelRunner:
|
|||||||
"torchao_config": server_args.torchao_config,
|
"torchao_config": server_args.torchao_config,
|
||||||
"disable_penalizer": server_args.disable_penalizer,
|
"disable_penalizer": server_args.disable_penalizer,
|
||||||
"disable_nan_detection": server_args.disable_nan_detection,
|
"disable_nan_detection": server_args.disable_nan_detection,
|
||||||
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -592,11 +593,18 @@ class ModelRunner:
|
|||||||
get_embedding=True,
|
get_embedding=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_idle(self, forward_batch: ForwardBatch):
    """Call the model's ``forward`` for an idle batch and return its result.

    The model receives the batch's ``input_ids`` and ``positions`` together
    with the batch itself.
    """
    model_forward = self.model.forward
    return model_forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
    )
|
||||||
|
|
||||||
def forward(self, forward_batch: "ForwardBatch") -> "LogitsProcessorOutput":
    """Dispatch a forward pass according to the batch's forward mode.

    The mode predicates are checked in order: decode, extend, idle. The
    matching ``forward_decode`` / ``forward_extend`` / ``forward_idle``
    method is called and its result returned unchanged.

    Raises:
        ValueError: if none of the three predicates holds for the mode.
    """
    mode = forward_batch.forward_mode
    if mode.is_decode():
        return self.forward_decode(forward_batch)
    if mode.is_extend():
        return self.forward_extend(forward_batch)
    if mode.is_idle():
        return self.forward_idle(forward_batch)
    # The message previously read "Invaid forward mode"; spelling corrected.
    raise ValueError(f"Invalid forward mode: {mode}")
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
|
get_tp_group,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
cache_config=None,
|
cache_config=None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
layer_id=None,
|
layer_id=None,
|
||||||
|
use_dp=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
assert num_heads % tp_size == 0
|
assert num_heads % tp_size == 0
|
||||||
self.num_local_heads = num_heads // tp_size
|
self.num_local_heads = num_heads if use_dp else num_heads // tp_size
|
||||||
self.scaling = self.qk_head_dim**-0.5
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if use_dp:
|
||||||
self.q_a_proj = ReplicatedLinear(
|
# For data parallel attention
|
||||||
self.hidden_size,
|
if self.q_lora_rank is not None:
|
||||||
self.q_lora_rank,
|
self.q_a_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.q_lora_rank,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
self.q_b_proj = ReplicatedLinear(
|
||||||
|
q_lora_rank,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.kv_b_proj = ReplicatedLinear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
# O projection.
|
||||||
self.q_b_proj = ColumnParallelLinear(
|
self.o_proj = ReplicatedLinear(
|
||||||
q_lora_rank,
|
self.num_heads * self.v_head_dim,
|
||||||
self.num_heads * self.qk_head_dim,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.q_proj = ColumnParallelLinear(
|
# For tensor parallel attention
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
self.q_a_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.q_lora_rank,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
self.q_b_proj = ColumnParallelLinear(
|
||||||
|
q_lora_rank,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_proj = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
# O projection.
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||||
self.kv_b_proj = ColumnParallelLinear(
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
|
||||||
# O projection.
|
|
||||||
self.o_proj = RowParallelLinear(
|
|
||||||
self.num_heads * self.v_head_dim,
|
|
||||||
self.hidden_size,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
|
||||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather(
    input_tensor: torch.Tensor, forward_batch: "ForwardBatch", rank, world_size, group
):
    """Gather a variable number of rows from every rank into one tensor.

    Each rank pads its rows up to the largest per-rank count, the padded
    chunks are all-gathered into ``forward_batch.gathered_buffer``, and the
    padding is stripped again while concatenating the chunks in rank order.

    Args:
        input_tensor: This rank's rows. The padding spec below pads the
            second-to-last dimension, so a 2-D (rows, features) tensor is
            assumed -- TODO confirm against callers.
        forward_batch: Provides ``global_num_tokens`` (row count per rank)
            and the preallocated ``gathered_buffer``.
        rank: Index of this rank inside ``group``.
        world_size: Number of ranks in ``group``.
        group: Process group passed to ``all_gather_into_tensor``.

    Returns:
        ``(gathered, start_index, end_index)`` where
        ``gathered[start_index:end_index]`` are this rank's own rows.
    """
    if world_size == 1:
        # Return the same 3-tuple shape as the multi-rank path; a bare tensor
        # here cannot be unpacked by callers that expect (tensor, start, end).
        return input_tensor, 0, input_tensor.shape[0]

    all_lens = forward_batch.global_num_tokens
    max_len = max(all_lens)

    # Pad rows so every rank contributes an equally sized chunk.
    padded_tensor = torch.nn.functional.pad(
        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
    )

    torch.distributed.all_gather_into_tensor(
        forward_batch.gathered_buffer, padded_tensor, group=group
    )

    # Rank i's chunk starts at i * max_len; keep only its real rows.
    gathered_tensors = torch.concat(
        [
            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
            for i in range(world_size)
        ]
    )

    # sum() over an empty slice is 0, so rank 0 needs no special case.
    start_index = sum(all_lens[:rank])
    end_index = start_index + all_lens[rank]

    return gathered_tensors, start_index, end_index
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
|
self.enable_dp_attention = (
|
||||||
|
not global_server_args_dict["disable_mla"]
|
||||||
|
and global_server_args_dict["enable_dp_attention"]
|
||||||
|
)
|
||||||
|
if self.enable_dp_attention:
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_group = get_tp_group().device_group
|
||||||
if not global_server_args_dict["disable_mla"]:
|
if not global_server_args_dict["disable_mla"]:
|
||||||
self.self_attn = DeepseekV2AttentionMLA(
|
self.self_attn = DeepseekV2AttentionMLA(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
use_dp=self.enable_dp_attention,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.self_attn = DeepseekV2Attention(
|
self.self_attn = DeepseekV2Attention(
|
||||||
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if residual is None:
|
if not forward_batch.forward_mode.is_idle():
|
||||||
residual = hidden_states
|
if residual is None:
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
residual = hidden_states
|
||||||
else:
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
else:
|
||||||
hidden_states = self.self_attn(
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
positions=positions,
|
|
||||||
hidden_states=hidden_states,
|
hidden_states = self.self_attn(
|
||||||
forward_batch=forward_batch,
|
positions=positions,
|
||||||
)
|
hidden_states=hidden_states,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
if self.enable_dp_attention:
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states, start_idx, end_idx = all_gather(
|
||||||
|
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
|
||||||
|
)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = hidden_states[start_idx:end_idx]
|
||||||
|
else:
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
|
|||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, forward_batch, residual
|
positions, hidden_states, forward_batch, residual
|
||||||
)
|
)
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
if not forward_batch.forward_mode.is_idle():
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(
|
if global_server_args_dict["enable_dp_attention"]:
|
||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
self.lm_head = ReplicatedLinear(
|
||||||
)
|
config.hidden_size,
|
||||||
self.logits_processor = LogitsProcessor(config)
|
config.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
|
)
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
if not forward_batch.forward_mode.is_idle():
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
return self.logits_processor(
|
||||||
)
|
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ class ServerArgs:
|
|||||||
disable_nan_detection: bool = False
|
disable_nan_detection: bool = False
|
||||||
enable_overlap_schedule: bool = False
|
enable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
|
enable_dp_attention: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: int = 160
|
cuda_graph_max_bs: int = 160
|
||||||
@@ -203,6 +204,16 @@ class ServerArgs:
|
|||||||
if self.sampling_backend is None:
|
if self.sampling_backend is None:
|
||||||
self.sampling_backend = "flashinfer"
|
self.sampling_backend = "flashinfer"
|
||||||
|
|
||||||
|
if self.enable_dp_attention:
|
||||||
|
self.dp_size = self.tp_size
|
||||||
|
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
||||||
|
self.disable_cuda_graph = True
|
||||||
|
self.enable_overlap_schedule = False
|
||||||
|
logger.warning(
|
||||||
|
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
|
||||||
|
"The CUDA graph is disabled."
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_overlap_schedule:
|
if self.enable_overlap_schedule:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Overlap scheduler mode is enabled. This is an experimental feature. "
|
"Overlap scheduler mode is enabled. This is an experimental feature. "
|
||||||
@@ -669,6 +680,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
|
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-dp-attention",
|
||||||
|
action="store_true",
|
||||||
|
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-torch-compile",
|
"--enable-torch-compile",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
63
test/srt/test_dp_attention.py
Normal file
63
test/srt/test_dp_attention.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDPAttention(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--tp",
|
||||||
|
"2",
|
||||||
|
"--dp",
|
||||||
|
"2",
|
||||||
|
"--enable-dp-attention",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
|
def test_mmlu(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
assert metrics["score"] >= 0.5
|
||||||
|
|
||||||
|
def test_mgsm_en(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mgsm_en",
|
||||||
|
num_examples=None,
|
||||||
|
num_threads=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
assert metrics["score"] >= 0.8
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user