Fix the deadlock in multi-node tp (#1122)
This commit is contained in:
@@ -64,7 +64,9 @@ def main(args):
|
|||||||
@sgl.function
|
@sgl.function
|
||||||
def few_shot_gsm8k(s, question):
|
def few_shot_gsm8k(s, question):
|
||||||
s += few_shot_examples + question
|
s += few_shot_examples + question
|
||||||
s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
|
s += sgl.gen(
|
||||||
|
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
|
||||||
|
)
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
########## SGL Program End ##########
|
########## SGL Program End ##########
|
||||||
|
|||||||
@@ -67,10 +67,12 @@ class LogitsMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, skip_all_gather: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.do_tensor_parallel_all_gather = (
|
||||||
|
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||||
|
)
|
||||||
|
|
||||||
def _get_normalized_prompt_logprobs(
|
def _get_normalized_prompt_logprobs(
|
||||||
self, input_token_logprobs, logits_metadata: LogitsMetadata
|
self, input_token_logprobs, logits_metadata: LogitsMetadata
|
||||||
@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.do_tensor_parallel_all_gather:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
all_logits = torch.matmul(hidden_states, weight.T)
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.do_tensor_parallel_all_gather:
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_group
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -724,7 +726,7 @@ class ScheduleBatch:
|
|||||||
)
|
)
|
||||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
||||||
|
|
||||||
def sample(self, logits: torch.Tensor):
|
def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
|
||||||
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
|
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
@@ -779,6 +781,16 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
||||||
|
|
||||||
|
if is_multi_node_tp:
|
||||||
|
# If the tensor parallelism spans across multiple nodes, there is some nondeterminism
|
||||||
|
# that can cause the TP workers to generate different tokens, so we need to
|
||||||
|
# sync here (all-reduce with MIN so every TP rank keeps the same token ids)
|
||||||
|
torch.distributed.all_reduce(
|
||||||
|
batch_next_token_ids,
|
||||||
|
op=dist.ReduceOp.MIN,
|
||||||
|
group=get_tensor_model_parallel_group().device_group,
|
||||||
|
)
|
||||||
|
|
||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -85,10 +85,6 @@ class ModelTpServer:
|
|||||||
self.schedule_policy = server_args.schedule_policy
|
self.schedule_policy = server_args.schedule_policy
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
|
|
||||||
# Chunked prefill
|
|
||||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
|
||||||
self.current_inflight_req = None
|
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
@@ -175,6 +171,10 @@ class ModelTpServer:
|
|||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
self.last_stats_tic = time.time()
|
self.last_stats_tic = time.time()
|
||||||
|
|
||||||
|
# Chunked prefill
|
||||||
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||||
|
self.current_inflight_req = None
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
if not server_args.skip_tokenizer_init:
|
if not server_args.skip_tokenizer_init:
|
||||||
self.regex_fsm_cache = FSMCache(
|
self.regex_fsm_cache = FSMCache(
|
||||||
@@ -444,7 +444,9 @@ class ModelTpServer:
|
|||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
next_token_ids = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(
|
||||||
|
output.next_token_logits, self.model_runner.is_multi_node_tp
|
||||||
|
)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
@@ -603,7 +605,9 @@ class ModelTpServer:
|
|||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
next_token_ids = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(
|
||||||
|
output.next_token_logits, self.model_runner.is_multi_node_tp
|
||||||
|
)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class CudaGraphRunner:
|
|||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
def can_run(self, batch_size):
|
def can_run(self, batch_size):
|
||||||
return batch_size < self.max_bs
|
return batch_size <= self.max_bs
|
||||||
|
|
||||||
def capture(self, batch_size_list):
|
def capture(self, batch_size_list):
|
||||||
self.batch_size_list = batch_size_list
|
self.batch_size_list = batch_size_list
|
||||||
@@ -239,12 +239,23 @@ class CudaGraphRunner:
|
|||||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
run_once()
|
run_once()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
|
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
self.graph_memory_pool = graph.pool()
|
self.graph_memory_pool = graph.pool()
|
||||||
return graph, None, out, flashinfer_decode_wrapper
|
return graph, None, out, flashinfer_decode_wrapper
|
||||||
|
|
||||||
@@ -278,7 +289,9 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
|
torch.cuda.synchronize()
|
||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
output = self.output_buffers[bs]
|
output = self.output_buffers[bs]
|
||||||
|
|
||||||
# Unpad
|
# Unpad
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from vllm.distributed import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
@@ -112,10 +113,13 @@ class ModelRunner:
|
|||||||
distributed_init_method=nccl_init_method,
|
distributed_init_method=nccl_init_method,
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
self.tp_group = get_tp_group()
|
|
||||||
total_gpu_memory = get_available_gpu_memory(
|
total_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
|
self.tp_group = get_tp_group()
|
||||||
|
self.is_multi_node_tp = not all(
|
||||||
|
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
|
||||||
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
|
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
|
||||||
|
|||||||
@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Grok1Model(config, quant_config=quant_config)
|
self.model = Grok1Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
|
||||||
|
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||||
|
|||||||
Reference in New Issue
Block a user