Add device support (#1607)
This commit is contained in:
@@ -423,6 +423,9 @@ class ScheduleBatch:
|
|||||||
# Stream
|
# Stream
|
||||||
has_stream: bool = False
|
has_stream: bool = False
|
||||||
|
|
||||||
|
# device
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
# Has regex
|
# Has regex
|
||||||
has_regex: bool = False
|
has_regex: bool = False
|
||||||
|
|
||||||
@@ -439,6 +442,7 @@ class ScheduleBatch:
|
|||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
has_stream=has_stream,
|
has_stream=has_stream,
|
||||||
|
device=req_to_token_pool.device,
|
||||||
has_regex=has_regex,
|
has_regex=has_regex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -81,10 +81,11 @@ class ModelRunner:
|
|||||||
# Parse args
|
# Parse args
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.mem_fraction_static = mem_fraction_static
|
self.mem_fraction_static = mem_fraction_static
|
||||||
|
self.device = server_args.device
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.nccl_port = nccl_port
|
self.dist_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.is_multimodal_model = is_multimodal_model(
|
self.is_multimodal_model = is_multimodal_model(
|
||||||
self.model_config.hf_config.architectures
|
self.model_config.hf_config.architectures
|
||||||
@@ -132,39 +133,45 @@ class ModelRunner:
|
|||||||
server_args.max_running_requests,
|
server_args.max_running_requests,
|
||||||
server_args.max_total_tokens,
|
server_args.max_total_tokens,
|
||||||
)
|
)
|
||||||
self.init_cublas()
|
if self.device == "cuda":
|
||||||
self.init_attention_backend()
|
self.init_cublas()
|
||||||
self.init_cuda_graphs()
|
self.init_attention_backend()
|
||||||
|
self.init_cuda_graphs()
|
||||||
|
else:
|
||||||
|
self.init_attention_backend()
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
|
logger.info("Init torch distributed begin.")
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
torch.cuda.set_device(self.gpu_id)
|
if self.device == "cuda":
|
||||||
logger.info("Init nccl begin.")
|
torch.cuda.set_device(self.gpu_id)
|
||||||
|
backend = "nccl"
|
||||||
|
|
||||||
if not self.server_args.enable_p2p_check:
|
if not self.server_args.enable_p2p_check:
|
||||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
||||||
|
|
||||||
if self.server_args.dist_init_addr:
|
if self.server_args.dist_init_addr:
|
||||||
nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
||||||
else:
|
else:
|
||||||
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
|
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
|
||||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
backend="nccl",
|
backend=backend,
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size,
|
||||||
rank=self.tp_rank,
|
rank=self.tp_rank,
|
||||||
local_rank=self.gpu_id,
|
local_rank=self.gpu_id,
|
||||||
distributed_init_method=nccl_init_method,
|
distributed_init_method=dist_init_method,
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
min_per_gpu_memory = get_available_gpu_memory(
|
min_per_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
|
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
|
||||||
# so we disable padding in cuda graph.
|
# so we disable padding in cuda graph.
|
||||||
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
|
if self.device == "cuda" and not all(
|
||||||
|
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
|
||||||
|
):
|
||||||
self.server_args.disable_cuda_graph_padding = True
|
self.server_args.disable_cuda_graph_padding = True
|
||||||
logger.info(
|
logger.info(
|
||||||
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
|
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
|
||||||
@@ -172,7 +179,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Check memory for tensor parallelism
|
# Check memory for tensor parallelism
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
|
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
if min_per_gpu_memory < local_gpu_memory * 0.9:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
||||||
@@ -182,23 +189,22 @@ class ModelRunner:
|
|||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
# This can reduce thread conflicts and speed up weight loading.
|
# This can reduce thread conflicts and speed up weight loading.
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
if self.device == "cuda":
|
||||||
if torch.cuda.get_device_capability()[0] < 8:
|
if torch.cuda.get_device_capability()[0] < 8:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
|
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
|
||||||
)
|
)
|
||||||
self.server_args.dtype = "float16"
|
self.server_args.dtype = "float16"
|
||||||
if torch.cuda.get_device_capability()[1] < 5:
|
if torch.cuda.get_device_capability()[1] < 5:
|
||||||
raise RuntimeError("SGLang only supports sm75 and above.")
|
raise RuntimeError("SGLang only supports sm75 and above.")
|
||||||
|
|
||||||
# Prepare the vllm model config
|
# Prepare the vllm model config
|
||||||
monkey_patch_vllm_dummy_weight_loader()
|
monkey_patch_vllm_dummy_weight_loader()
|
||||||
self.device_config = DeviceConfig()
|
|
||||||
self.load_config = LoadConfig(load_format=self.server_args.load_format)
|
self.load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
self.vllm_model_config = VllmModelConfig(
|
self.vllm_model_config = VllmModelConfig(
|
||||||
model=self.server_args.model_path,
|
model=self.server_args.model_path,
|
||||||
@@ -220,7 +226,7 @@ class ModelRunner:
|
|||||||
self.model = get_model(
|
self.model = get_model(
|
||||||
model_config=self.vllm_model_config,
|
model_config=self.vllm_model_config,
|
||||||
load_config=self.load_config,
|
load_config=self.load_config,
|
||||||
device_config=self.device_config,
|
device_config=DeviceConfig(self.device),
|
||||||
parallel_config=None,
|
parallel_config=None,
|
||||||
scheduler_config=None,
|
scheduler_config=None,
|
||||||
lora_config=None,
|
lora_config=None,
|
||||||
@@ -240,7 +246,7 @@ class ModelRunner:
|
|||||||
f"Load weight end. "
|
f"Load weight end. "
|
||||||
f"type={type(self.model).__name__}, "
|
f"type={type(self.model).__name__}, "
|
||||||
f"dtype={self.dtype}, "
|
f"dtype={self.dtype}, "
|
||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_weights(self, model_path: str, load_format: str):
|
def update_weights(self, model_path: str, load_format: str):
|
||||||
@@ -254,10 +260,10 @@ class ModelRunner:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Update weights begin. "
|
f"Update weights begin. "
|
||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
target_device = torch.device(self.device_config.device)
|
target_device = torch.device(self.device)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Use a better method to check this
|
# TODO: Use a better method to check this
|
||||||
@@ -343,7 +349,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory: int):
|
def profile_max_num_token(self, total_gpu_memory: int):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
@@ -409,11 +415,10 @@ class ModelRunner:
|
|||||||
4096,
|
4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
device = "cuda"
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
size=max_num_reqs + 1,
|
size=max_num_reqs + 1,
|
||||||
max_context_len=self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len + 4,
|
||||||
device=device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
@@ -425,7 +430,7 @@ class ModelRunner:
|
|||||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
device=device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
@@ -434,11 +439,11 @@ class ModelRunner:
|
|||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
device=device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Memory pool end. "
|
f"Memory pool end. "
|
||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cublas(self):
|
def init_cublas(self):
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
|
|||||||
linear_penalties: torch.Tensor = None
|
linear_penalties: torch.Tensor = None
|
||||||
scaling_penalties: torch.Tensor = None
|
scaling_penalties: torch.Tensor = None
|
||||||
|
|
||||||
|
# Device
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
|
|||||||
min_ps=min_ps,
|
min_ps=min_ps,
|
||||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
|
device=batch.input_ids.device,
|
||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
|
|
||||||
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
|
|||||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
device="cuda",
|
device=batch.input_ids.device,
|
||||||
Penalizers={
|
Penalizers={
|
||||||
penaltylib.BatchedFrequencyPenalizer,
|
penaltylib.BatchedFrequencyPenalizer,
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = torch.zeros(
|
self.linear_penalties = torch.zeros(
|
||||||
(bs, self.vocab_size),
|
(bs, self.vocab_size),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda",
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
if has_regex:
|
if has_regex:
|
||||||
self.vocab_mask = torch.zeros(
|
self.vocab_mask = torch.zeros(
|
||||||
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
|
len(self.temperatures),
|
||||||
|
self.vocab_size,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||||
if regex_fsm is not None:
|
if regex_fsm is not None:
|
||||||
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
|
lhs: torch.Tensor,
|
||||||
|
rhs: torch.Tensor,
|
||||||
|
bs1: int,
|
||||||
|
bs2: int,
|
||||||
|
device: str,
|
||||||
|
default: int = 0,
|
||||||
):
|
):
|
||||||
# bias tensor can be None
|
# bias tensor can be None
|
||||||
if lhs is not None or rhs is not None:
|
if lhs is not None or rhs is not None:
|
||||||
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
|
|||||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||||
with torch.dtype(dtype):
|
with torch.dtype(dtype):
|
||||||
if lhs is None:
|
if lhs is None:
|
||||||
lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
|
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
|
||||||
if rhs is None:
|
if rhs is None:
|
||||||
rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
|
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
|
||||||
return torch.cat([lhs, rhs])
|
return torch.cat([lhs, rhs])
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
|
|||||||
setattr(self, item, torch.concat([self_val, other_val]))
|
setattr(self, item, torch.concat([self_val, other_val]))
|
||||||
|
|
||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class ServerArgs:
|
|||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
device: str = "cuda"
|
||||||
kv_cache_dtype: str = "auto"
|
kv_cache_dtype: str = "auto"
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
@@ -237,6 +238,13 @@ class ServerArgs:
|
|||||||
'* "float" is shorthand for FP32 precision.\n'
|
'* "float" is shorthand for FP32 precision.\n'
|
||||||
'* "float32" for FP32 precision.',
|
'* "float32" for FP32 precision.',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
choices=["cuda"],
|
||||||
|
help="The device type.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_available_gpu_memory(gpu_id, distributed=False):
|
def get_available_gpu_memory(device, gpu_id, distributed=False):
|
||||||
"""
|
"""
|
||||||
Get available memory for cuda:gpu_id device.
|
Get available memory for cuda:gpu_id device.
|
||||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||||
"""
|
"""
|
||||||
num_gpus = torch.cuda.device_count()
|
if device == "cuda":
|
||||||
assert gpu_id < num_gpus
|
num_gpus = torch.cuda.device_count()
|
||||||
|
assert gpu_id < num_gpus
|
||||||
|
|
||||||
if torch.cuda.current_device() != gpu_id:
|
if torch.cuda.current_device() != gpu_id:
|
||||||
print(
|
print(
|
||||||
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||||
"which may cause useless memory allocation for torch CUDA context.",
|
"which may cause useless memory allocation for torch CUDA context.",
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||||
|
|
||||||
|
elif device == "xpu":
|
||||||
|
num_gpus = torch.xpu.device_count()
|
||||||
|
assert gpu_id < num_gpus
|
||||||
|
|
||||||
|
if torch.xpu.current_device() != gpu_id:
|
||||||
|
print(
|
||||||
|
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
|
||||||
|
"which may cause useless memory allocation for torch XPU context.",
|
||||||
|
)
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
used_memory = torch.xpu.memory_allocated()
|
||||||
|
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
|
||||||
|
free_gpu_memory = total_gpu_memory - used_memory
|
||||||
|
|
||||||
if distributed:
|
if distributed:
|
||||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||||
torch.device("cuda", gpu_id)
|
torch.device(device, gpu_id)
|
||||||
)
|
)
|
||||||
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
||||||
free_gpu_memory = tensor.item()
|
free_gpu_memory = tensor.item()
|
||||||
|
|||||||
Reference in New Issue
Block a user