Fix the deadlock in multi-node tp (#1122)
```diff
@@ -67,10 +67,12 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )
 
     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata
```

```diff
@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
```

```diff
@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
                 )
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
+                if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
                 all_logits = all_logits[:, : self.config.vocab_size].float()
 
```
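The new `skip_all_gather` flag lets a model whose lm_head is replicated rather than vocab-sharded bypass the tensor-parallel all-gather entirely: each rank already holds full-vocabulary logits, so the collective is pure overhead and one more place for multi-node ranks to fall out of step. A minimal single-process sketch of the gating; `ToyLogitsProcessor`, `tp_size`, and `_all_gather_stub` are illustrative stand-ins, not the sglang/vllm API:

```python
import torch


def _all_gather_stub(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for vllm's tensor_model_parallel_all_gather, which concatenates
    # each rank's vocab shard along the last dim; identity on a single process.
    return x


class ToyLogitsProcessor:
    def __init__(self, vocab_size: int, tp_size: int, skip_all_gather: bool = False):
        self.vocab_size = vocab_size
        # Gather only when TP is active AND the lm_head is vocab-sharded.
        self.do_tensor_parallel_all_gather = not skip_all_gather and tp_size > 1

    def __call__(self, last_hidden: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        last_logits = last_hidden @ weight.T
        if self.do_tensor_parallel_all_gather:
            last_logits = _all_gather_stub(last_logits)
        return last_logits[:, : self.vocab_size].float()


# With a replicated (full-vocab) head weight, the gather is skipped outright.
proc = ToyLogitsProcessor(vocab_size=32000, tp_size=2, skip_all_gather=True)
assert proc(torch.randn(4, 512), torch.randn(32000, 512)).shape == (4, 32000)
```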
```diff
@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
```

```diff
@@ -724,7 +726,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
```

```diff
@@ -779,6 +781,16 @@ class ScheduleBatch:
 
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
 
+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is
+            # some indeterminism that can cause the TP workers to generate
+            # different tokens, so we need to sync here.
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids
```
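This all-reduce is one half of the fix: sampling is not bit-deterministic across TP workers, so ranks holding identical logits can still draw different token ids, and once their sequences diverge, subsequent collectives mismatch and the job hangs. The MIN all-reduce is simply a cheap deterministic election that puts every rank back on the same token. A self-contained sketch on the `gloo` backend (single-process world so it runs anywhere; the group and tensor contents are illustrative):

```python
import os
import torch
import torch.distributed as dist


def sync_sampled_tokens(token_ids: torch.Tensor, group=None) -> torch.Tensor:
    """Force every TP rank to agree on the sampled token ids.

    Any deterministic election would do; MIN, as in the commit, picks the
    smallest candidate id proposed by any rank.
    """
    dist.all_reduce(token_ids, op=dist.ReduceOp.MIN, group=group)
    return token_ids


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    ids = torch.tensor([17, 3, 42])
    print(sync_sampled_tokens(ids))  # unchanged with a single rank
    dist.destroy_process_group()
```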
```diff
@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
```

```diff
@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
```

```diff
@@ -444,7 +444,9 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                next_token_ids = batch.sample(
+                    output.next_token_logits, self.model_runner.is_multi_node_tp
+                )
 
                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
```

```diff
@@ -603,7 +605,9 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
```
```diff
@@ -142,7 +142,7 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
```
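A small off-by-one rides along here: graphs are captured for every size in `batch_size_list`, including its maximum, so a batch of exactly `self.max_bs` can replay a captured graph and should not be bounced to the eager path. Illustrative values:

```python
batch_size_list = [1, 2, 4, 8]   # hypothetical capture list
max_bs = max(batch_size_list)    # a graph exists for every listed size

assert not (8 < max_bs)   # old check: a full batch of 8 fell back to eager
assert 8 <= max_bs        # new check: the captured max_bs graph is replayed
```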
```diff
@@ -239,12 +239,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
```
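Capture is the other half of the fix: the warmup runs issue real NCCL collectives, and if one rank begins capturing while another is still warming up, their collectives can pair up wrongly and deadlock. Bracketing each warmup iteration and the capture itself with a device sync plus a TP-group barrier keeps all ranks in lockstep. The shape of the pattern, sketched with plain `torch.distributed` (assumes a CUDA device and an initialized process group; the real code calls vllm's `tp_group.barrier()`):

```python
import torch
import torch.distributed as dist


def lockstep(group=None):
    # Drain this rank's GPU queue, then rendezvous with the other TP ranks,
    # so no rank can race ahead into (or out of) graph capture.
    torch.cuda.synchronize()
    dist.barrier(group=group)


def capture_in_lockstep(graph, run_once, pool=None, group=None):
    for _ in range(2):        # warmup passes, fenced on both sides
        lockstep(group)
        run_once()
        lockstep(group)
    lockstep(group)
    with torch.cuda.graph(graph, pool=pool):
        out = run_once()
    lockstep(group)
    return out
```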
```diff
@@ -278,7 +289,9 @@ class CudaGraphRunner:
         )
 
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]
 
         # Unpad
```
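Replay gets a matching pair of fences: the `synchronize()` before `replay()` acts as a conservative guarantee that the staged input-buffer copies have finished before the graph consumes them, and the one after ensures the graph's outputs are complete before any rank reads them or enters the next collective. A small per-step latency cost in exchange for keeping multi-node TP ranks aligned.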
```diff
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
```

```diff
@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )
 
         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
```
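vllm's `in_the_same_node_as` returns one boolean per rank in the group, true when that rank shares `source_rank`'s node, so TP is multi-node exactly when not all entries are true. Gating on `is_multi_node_tp` means the extra sampler all-reduce is only paid when it is actually needed. A rough equivalent without vllm compares hostnames across the group (assumes an initialized process group; `all_gather_object` is standard `torch.distributed`, though hostname comparison only approximates vllm's shared-memory test):

```python
import socket
import torch.distributed as dist


def is_multi_node(group=None) -> bool:
    """True when the ranks of `group` span more than one physical host."""
    hostnames = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(hostnames, socket.gethostname(), group=group)
    return len(set(hostnames)) > 1
```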
```diff
@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
```
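Grok-1 then opts in: its lm_head becomes a `ReplicatedLinear` (note vllm's `input_size, output_size` argument order, the reverse of `ParallelLMHead`'s `num_embeddings, embedding_dim`), and its logits processor is built with `skip_all_gather=True`. Every rank now projects to the full vocabulary locally, removing the logits all-gather from the multi-node critical path at the price of holding the whole `hidden_size x vocab_size` head on each rank. Shape-level intuition with illustrative sizes:

```python
import torch

hidden_size, vocab_size, tp_size = 512, 32000, 4

# Vocab-sharded head: each rank holds vocab_size // tp_size rows, so its
# logits are partial and must be all-gathered before sampling.
shard = torch.randn(vocab_size // tp_size, hidden_size)

# Replicated head: each rank holds the full projection; logits are complete
# and the gather (a multi-node failure point) disappears.
replica = torch.randn(vocab_size, hidden_size)

h = torch.randn(1, hidden_size)
assert (h @ shard.T).shape == (1, vocab_size // tp_size)  # partial logits
assert (h @ replica.T).shape == (1, vocab_size)           # full logits
```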