Add Tensor Parallel to torch_native_llama (#1876)
This commit is contained in:
@@ -148,6 +148,15 @@ class ModelRunner:
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
|
||||
# Apply torch TP if model supports it
|
||||
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
|
||||
if self.tp_size > 1 and supports_torch_tp:
|
||||
self.apply_torch_tp()
|
||||
self.torch_tp_applied = True
|
||||
else:
|
||||
self.torch_tp_applied = False
|
||||
|
||||
if server_args.lora_paths is not None:
|
||||
self.init_lora_manager()
|
||||
self.init_memory_pool(
|
||||
@@ -551,6 +560,13 @@ class ModelRunner:
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def apply_torch_tp(self):
|
||||
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
||||
from sglang.srt.model_parallel import tensor_parallel
|
||||
|
||||
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
|
||||
tensor_parallel(self.model, device_mesh)
|
||||
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
Reference in New Issue
Block a user