Add more support for Intel Gaudi accelerators (#2357)
This commit is contained in:
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||
# int32 range is enough to represent the token ids
|
||||
probs_idx = probs_idx.to(torch.int32)
|
||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
||||
return batch_next_token_ids
|
||||
|
||||
@@ -993,7 +993,7 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
@@ -1055,7 +1055,7 @@ class Scheduler:
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
@@ -1130,7 +1130,7 @@ class Scheduler:
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
|
||||
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
|
||||
# Launch threads
|
||||
self.input_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
self.forward_stream = torch.cuda.Stream()
|
||||
self.forward_stream = torch.get_device_module(self.device).Stream()
|
||||
self.forward_thread = threading.Thread(
|
||||
target=self.forward_thread_func,
|
||||
)
|
||||
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
|
||||
|
||||
def forward_thread_func(self):
|
||||
try:
|
||||
with torch.cuda.stream(self.forward_stream):
|
||||
with torch.get_device_module(self.device).stream(self.forward_stream):
|
||||
self.forward_thread_func_()
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
|
||||
|
||||
# Create event
|
||||
self.launch_done = threading.Event()
|
||||
copy_done = torch.cuda.Event()
|
||||
copy_done = torch.get_device_module(self.device).Event()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
|
||||
)
|
||||
|
||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
|
||||
# Push a new batch to the queue
|
||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
|
||||
return select_index.to(self.device, non_blocking=True)
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
||||
else:
|
||||
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
# This compiled version is slower in the unit test
|
||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||
@torch.compile(dynamic=True)
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
|
||||
|
||||
|
||||
@torch.compile
|
||||
@torch.compile(backend=get_compiler_backend())
|
||||
def layer_norm_func(hidden_states, weight, variance_epsilon):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
@@ -25,6 +25,7 @@ import torch
|
||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||
from sglang.srt.utils import (
|
||||
get_amdgpu_memory_capacity,
|
||||
get_hpu_memory_capacity,
|
||||
get_nvgpu_memory_capacity,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
@@ -158,6 +159,8 @@ class ServerArgs:
|
||||
gpu_mem = get_amdgpu_memory_capacity()
|
||||
elif torch.cuda.is_available():
|
||||
gpu_mem = get_nvgpu_memory_capacity()
|
||||
elif self.device == "hpu":
|
||||
gpu_mem = get_hpu_memory_capacity()
|
||||
else:
|
||||
# GPU memory is not known yet or no GPU is available.
|
||||
gpu_mem = None
|
||||
@@ -194,6 +197,10 @@ class ServerArgs:
|
||||
self.cuda_graph_max_bs = 160
|
||||
|
||||
# Choose kernel backends
|
||||
if self.device == "hpu":
|
||||
self.attention_backend = "torch_native"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
if self.attention_backend is None:
|
||||
self.attention_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
|
||||
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
|
||||
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
|
||||
free_gpu_memory = total_gpu_memory - used_memory
|
||||
|
||||
elif device == "hpu":
|
||||
num_gpus = torch.hpu.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
|
||||
if torch.hpu.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch HPU context.",
|
||||
)
|
||||
|
||||
free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
|
||||
|
||||
if distributed:
|
||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||
torch.device(device, gpu_id)
|
||||
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
|
||||
)
|
||||
|
||||
|
||||
def get_hpu_memory_capacity():
    """Return the smallest "Total" memory value reported by ``hl-smi --query``.

    ``hl-smi`` is invoked directly (no shell pipeline) and every output line
    containing ``Total`` is inspected; the second-to-last whitespace-separated
    field of such a line is taken as the numeric value.
    NOTE(review): the value is presumably in MiB (``Total : <value> MiB``) --
    confirm against the installed hl-smi version.

    Returns:
        float: the minimum of all parsed "Total" values, so the most
        constrained device decides the result.

    Raises:
        RuntimeError: if ``hl-smi`` cannot be found or exits with a non-zero
            status.
        ValueError: if no numeric "Total" value can be parsed from the output.
    """
    try:
        # Run hl-smi without a shell: a missing binary then surfaces as
        # FileNotFoundError, and the return code is hl-smi's own rather than
        # the exit status of a trailing `grep` in a pipeline.
        result = subprocess.run(
            ["hl-smi", "--query"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            "hl-smi not found. Ensure Habana drivers are installed and accessible."
        ) from err

    if result.returncode != 0:
        raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")

    # Filter in Python what the original delegated to `grep 'Total'`.
    memory_values = []
    for line in result.stdout.splitlines():
        if "Total" not in line:
            continue
        fields = line.split()
        if len(fields) < 2:
            # Too short to hold "<value> <unit>"; nothing to parse here.
            continue
        try:
            memory_values.append(float(fields[-2]))
        except ValueError:
            # A "Total" line that does not carry a numeric value; skip it.
            continue

    if not memory_values:
        raise ValueError("No GPU memory values found.")

    # Return the minimum memory value
    return min(memory_values)
|
||||
|
||||
|
||||
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
|
||||
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
|
||||
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
||||
return major, minor
|
||||
|
||||
|
||||
def get_compiler_backend() -> str:
    """Pick the ``torch.compile`` backend name for the current environment.

    Returns ``"hpu_backend"`` when ``torch`` exposes an ``hpu`` attribute and
    ``torch.hpu.is_available()`` is truthy; otherwise returns ``"inductor"``.
    """
    # Short-circuit keeps torch.hpu untouched when the attribute is absent.
    hpu_usable = hasattr(torch, "hpu") and torch.hpu.is_available()
    return "hpu_backend" if hpu_usable else "inductor"
|
||||
|
||||
|
||||
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user