Add more support for intel Gaudi accelerators (#2357)
This commit is contained in:
@@ -1,7 +1,13 @@
|
|||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
):
|
||||||
# Sample prompts.
|
# Sample prompts.
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -13,7 +19,7 @@ def main():
|
|||||||
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||||
|
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
llm = sgl.Engine(**dataclasses.asdict(server_args))
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
# Print the outputs.
|
# Print the outputs.
|
||||||
@@ -25,4 +31,8 @@ def main():
|
|||||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
main(server_args)
|
||||||
|
|||||||
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
|||||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
# int32 range is enough to represent the token ids
|
||||||
|
probs_idx = probs_idx.to(torch.int32)
|
||||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|||||||
@@ -993,7 +993,7 @@ class Scheduler:
|
|||||||
self.process_batch_result_prefill(batch, result)
|
self.process_batch_result_prefill(batch, result)
|
||||||
elif batch.forward_mode.is_dummy_first():
|
elif batch.forward_mode.is_dummy_first():
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.get_device_module(self.device).current_stream().synchronize()
|
||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||||
@@ -1055,7 +1055,7 @@ class Scheduler:
|
|||||||
|
|
||||||
if batch.next_batch_sampling_info:
|
if batch.next_batch_sampling_info:
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.get_device_module(self.device).current_stream().synchronize()
|
||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
@@ -1130,7 +1130,7 @@ class Scheduler:
|
|||||||
|
|
||||||
if batch.next_batch_sampling_info:
|
if batch.next_batch_sampling_info:
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.get_device_module(self.device).current_stream().synchronize()
|
||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
self.stream_output(batch.reqs)
|
self.stream_output(batch.reqs)
|
||||||
|
|||||||
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import get_compiler_backend
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
||||||
input_ids[:] = torch.where(
|
input_ids[:] = torch.where(
|
||||||
input_ids < 0,
|
input_ids < 0,
|
||||||
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
|
|||||||
# Launch threads
|
# Launch threads
|
||||||
self.input_queue = Queue()
|
self.input_queue = Queue()
|
||||||
self.output_queue = Queue()
|
self.output_queue = Queue()
|
||||||
self.forward_stream = torch.cuda.Stream()
|
self.forward_stream = torch.get_device_module(self.device).Stream()
|
||||||
self.forward_thread = threading.Thread(
|
self.forward_thread = threading.Thread(
|
||||||
target=self.forward_thread_func,
|
target=self.forward_thread_func,
|
||||||
)
|
)
|
||||||
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
def forward_thread_func(self):
|
def forward_thread_func(self):
|
||||||
try:
|
try:
|
||||||
with torch.cuda.stream(self.forward_stream):
|
with torch.get_device_module(self.device).stream(self.forward_stream):
|
||||||
self.forward_thread_func_()
|
self.forward_thread_func_()
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
# Create event
|
# Create event
|
||||||
self.launch_done = threading.Event()
|
self.launch_done = threading.Event()
|
||||||
copy_done = torch.cuda.Event()
|
copy_done = torch.get_device_module(self.device).Event()
|
||||||
|
|
||||||
# Resolve future tokens in the input
|
# Resolve future tokens in the input
|
||||||
input_ids = model_worker_batch.input_ids
|
input_ids = model_worker_batch.input_ids
|
||||||
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.get_device_module(self.device).current_stream().synchronize()
|
||||||
|
|
||||||
# Push a new batch to the queue
|
# Push a new batch to the queue
|
||||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
|
|||||||
return select_index.to(self.device, non_blocking=True)
|
return select_index.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
def free(self, free_index: torch.Tensor):
|
def free(self, free_index: torch.Tensor):
|
||||||
|
if free_index.numel() == 0:
|
||||||
|
return
|
||||||
|
|
||||||
if self.is_not_in_free_group:
|
if self.is_not_in_free_group:
|
||||||
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
||||||
else:
|
else:
|
||||||
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
|
|
||||||
# This compiled version is slower in the unit test
|
# This compiled version is slower in the unit test
|
||||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
||||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||||
|
|||||||
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile(backend=get_compiler_backend())
|
||||||
def layer_norm_func(hidden_states, weight, variance_epsilon):
|
def layer_norm_func(hidden_states, weight, variance_epsilon):
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import torch
|
|||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
|
get_hpu_memory_capacity,
|
||||||
get_nvgpu_memory_capacity,
|
get_nvgpu_memory_capacity,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
@@ -158,6 +159,8 @@ class ServerArgs:
|
|||||||
gpu_mem = get_amdgpu_memory_capacity()
|
gpu_mem = get_amdgpu_memory_capacity()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
gpu_mem = get_nvgpu_memory_capacity()
|
gpu_mem = get_nvgpu_memory_capacity()
|
||||||
|
elif self.device == "hpu":
|
||||||
|
gpu_mem = get_hpu_memory_capacity()
|
||||||
else:
|
else:
|
||||||
# GPU memory is not known yet or no GPU is available.
|
# GPU memory is not known yet or no GPU is available.
|
||||||
gpu_mem = None
|
gpu_mem = None
|
||||||
@@ -194,6 +197,10 @@ class ServerArgs:
|
|||||||
self.cuda_graph_max_bs = 160
|
self.cuda_graph_max_bs = 160
|
||||||
|
|
||||||
# Choose kernel backends
|
# Choose kernel backends
|
||||||
|
if self.device == "hpu":
|
||||||
|
self.attention_backend = "torch_native"
|
||||||
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
if self.attention_backend is None:
|
if self.attention_backend is None:
|
||||||
self.attention_backend = (
|
self.attention_backend = (
|
||||||
"flashinfer" if is_flashinfer_available() else "triton"
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
|
|||||||
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
|
|||||||
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
|
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
|
||||||
free_gpu_memory = total_gpu_memory - used_memory
|
free_gpu_memory = total_gpu_memory - used_memory
|
||||||
|
|
||||||
|
elif device == "hpu":
|
||||||
|
num_gpus = torch.hpu.device_count()
|
||||||
|
assert gpu_id < num_gpus
|
||||||
|
|
||||||
|
if torch.hpu.current_device() != gpu_id:
|
||||||
|
print(
|
||||||
|
f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
|
||||||
|
"which may cause useless memory allocation for torch HPU context.",
|
||||||
|
)
|
||||||
|
|
||||||
|
free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
|
||||||
|
|
||||||
if distributed:
|
if distributed:
|
||||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||||
torch.device(device, gpu_id)
|
torch.device(device, gpu_id)
|
||||||
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hpu_memory_capacity():
|
||||||
|
try:
|
||||||
|
# Run hl-smi and capture the output
|
||||||
|
result = subprocess.run(
|
||||||
|
["hl-smi --query | grep 'Total'"],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
shell=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
|
||||||
|
|
||||||
|
# Parse the output to extract memory values in MiB
|
||||||
|
memory_values = [
|
||||||
|
float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
|
||||||
|
]
|
||||||
|
|
||||||
|
if not memory_values:
|
||||||
|
raise ValueError("No GPU memory values found.")
|
||||||
|
|
||||||
|
# Return the minimum memory value
|
||||||
|
return min(memory_values)
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"hl-smi not found. Ensure Habana drivers are installed and accessible."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
|
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
|
||||||
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
|
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
|
||||||
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
|
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
|
||||||
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
|||||||
return major, minor
|
return major, minor
|
||||||
|
|
||||||
|
|
||||||
|
def get_compiler_backend() -> str:
|
||||||
|
if hasattr(torch, "hpu") and torch.hpu.is_available():
|
||||||
|
return "hpu_backend"
|
||||||
|
|
||||||
|
return "inductor"
|
||||||
|
|
||||||
|
|
||||||
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user