Add Tensor Parallel to torch_native_llama (#1876)

This commit is contained in:
Ke Wen
2024-11-15 21:26:00 -08:00
committed by GitHub
parent e5c6715003
commit cf2489762b
5 changed files with 246 additions and 82 deletions

View File

@@ -148,6 +148,15 @@ class ModelRunner:
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()
# Apply torch TP if model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()
self.torch_tp_applied = True
else:
self.torch_tp_applied = False
if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
@@ -551,6 +560,13 @@ class ModelRunner:
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
def apply_torch_tp(self):
    """Shard the already-loaded model across TP ranks via torch-native tensor parallelism.

    Builds a 1-D device mesh of size ``self.tp_size`` on ``self.device`` and
    hands the model to ``tensor_parallel`` for in-place parallelization.
    """
    # Imported lazily so the module is only required when torch TP is enabled.
    from sglang.srt.model_parallel import tensor_parallel

    logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
    mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
    tensor_parallel(self.model, mesh)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)