diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py
index 116a8d842..e7ac8452d 100644
--- a/python/sglang/srt/managers/scheduler_profiler_mixin.py
+++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py
@@ -204,7 +204,7 @@ class SchedulerProfilerMixin:
             torch.distributed.barrier(self.tp_cpu_group)
             if self.tp_rank == 0:
-                from sglang.srt.utils import rpd_to_chrome_trace
+                from sglang.srt.rpd_utils import rpd_to_chrome_trace
 
                 rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
             self.rpd_profiler = None
diff --git a/python/sglang/srt/rpd_utils.py b/python/sglang/srt/rpd_utils.py
new file mode 100644
index 000000000..18b62d40f
--- /dev/null
+++ b/python/sglang/srt/rpd_utils.py
@@ -0,0 +1,452 @@
+# https://raw.githubusercontent.com/ROCm/rocmProfileData/refs/heads/master/tools/rpd2tracing.py
+# commit 92d13a08328625463e9ba944cece82fc5eea36e6
+def rpd_to_chrome_trace(
+    input_rpd, output_json=None, start="0%", end="100%", format="object"
+):
+    import gzip
+    import sqlite3
+
+    if output_json is None:
+        import pathlib
+
+        output_json = pathlib.PurePath(input_rpd).with_suffix(".trace.json.gz")
+
+    connection = sqlite3.connect(input_rpd)
+
+    outfile = gzip.open(output_json, "wt", encoding="utf-8")
+
+    if format == "object":
+        outfile.write('{"traceEvents": ')
+
+    outfile.write("[ {}\n")
+
+    for row in connection.execute("select distinct gpuId from rocpd_op"):
+        try:
+            outfile.write(
+                ',{"name": "process_name", "ph": "M", "pid":"%s","args":{"name":"%s"}}\n'
+                % (row[0], "GPU" + str(row[0]))
+            )
+            outfile.write(
+                ',{"name": "process_sort_index", "ph": "M", "pid":"%s","args":{"sort_index":"%s"}}\n'
+                % (row[0], row[0] + 1000000)
+            )
+        except ValueError:
+            outfile.write("")
+
+    for row in connection.execute("select distinct pid, tid from rocpd_api"):
+        try:
+            outfile.write(
+                ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
+                % (row[0], row[1], "Hip " + str(row[1]))
+            )
+            outfile.write(
+                ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
+                % (row[0], row[1], row[1] * 2)
+            )
+        except ValueError:
+            outfile.write("")
+
+    try:
+        # FIXME - these aren't rendering correctly in chrome://tracing
+        for row in connection.execute("select distinct pid, tid from rocpd_hsaApi"):
+            try:
+                outfile.write(
+                    ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
+                    % (row[0], row[1], "HSA " + str(row[1]))
+                )
+                outfile.write(
+                    ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
+                    % (row[0], row[1], row[1] * 2 - 1)
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    rangeStringApi = ""
+    rangeStringOp = ""
+    rangeStringMonitor = ""
+    min_time = connection.execute("select MIN(start) from rocpd_api;").fetchall()[0][0]
+    max_time = connection.execute("select MAX(end) from rocpd_api;").fetchall()[0][0]
+    if min_time == None:
+        raise Exception("Trace file is empty.")
+
+    print("Timestamps:")
+    print(f"\t first: \t{min_time/1000} us")
+    print(f"\t last: \t{max_time/1000} us")
+    print(f"\t duration: \t{(max_time-min_time) / 1000000000} seconds")
+
+    start_time = min_time / 1000
+    end_time = max_time / 1000
+
+    if start:
+        if "%" in start:
+            start_time = (
+                (max_time - min_time) * (int(start.replace("%", "")) / 100) + min_time
+            ) / 1000
+        else:
+            start_time = int(start)
+        rangeStringApi = "where rocpd_api.start/1000 >= %s" % (start_time)
+        rangeStringOp = "where rocpd_op.start/1000 >= %s" % (start_time)
+        rangeStringMonitor = "where start/1000 >= %s" % (start_time)
+    if end:
+        if "%" in end:
+            end_time = (
+                (max_time - min_time) * (int(end.replace("%", "")) / 100) + min_time
+            ) / 1000
+        else:
+            end_time = int(end)
+
+        rangeStringApi = (
+            rangeStringApi + " and rocpd_api.start/1000 <= %s" % (end_time)
+            if start != None
+            else "where rocpd_api.start/1000 <= %s" % (end_time)
+        )
+        rangeStringOp = (
+            rangeStringOp + " and rocpd_op.start/1000 <= %s" % (end_time)
+            if start != None
+            else "where rocpd_op.start/1000 <= %s" % (end_time)
+        )
+        rangeStringMonitor = (
+            rangeStringMonitor + " and start/1000 <= %s" % (end_time)
+            if start != None
+            else "where start/1000 <= %s" % (end_time)
+        )
+
+    print("\nFilter: %s" % (rangeStringApi))
+    print(f"Output duration: {(end_time-start_time)/1000000} seconds")
+
+    # Output Ops
+
+    for row in connection.execute(
+        "select A.string as optype, B.string as description, gpuId, queueId, rocpd_op.start/1000.0, (rocpd_op.end-rocpd_op.start) / 1000.0 from rocpd_op INNER JOIN rocpd_string A on A.id = rocpd_op.opType_id INNER Join rocpd_string B on B.id = rocpd_op.description_id %s"
+        % (rangeStringOp)
+    ):
+        try:
+            name = row[0] if len(row[1]) == 0 else row[1]
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                % (row[2], row[3], name, row[4], row[5], row[0])
+            )
+        except ValueError:
+            outfile.write("")
+
+    # Output Graph executions on GPU
+    try:
+        for row in connection.execute(
+            "select graphExec, gpuId, queueId, min(start)/1000.0, (max(end)-min(start))/1000.0, count(*) from rocpd_graphLaunchapi A join rocpd_api_ops B on B.api_id = A.api_ptr_id join rocpd_op C on C.id = B.op_id %s group by api_ptr_id"
+            % (rangeStringMonitor)
+        ):
+            try:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"kernels":"%s"}}\n'
+                    % (row[1], row[2], f"Graph {row[0]}", row[3], row[4], row[5])
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    # Output apis
+    for row in connection.execute(
+        "select A.string as apiName, B.string as args, pid, tid, rocpd_api.start/1000.0, (rocpd_api.end-rocpd_api.start) / 1000.0, (rocpd_api.end != rocpd_api.start) as has_duration from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id INNER Join rocpd_string B on B.id = rocpd_api.args_id %s order by rocpd_api.id"
+        % (rangeStringApi)
+    ):
+        try:
+            if row[0] == "UserMarker":
+                if row[6] == 0:  # instantanuous "mark" messages
+                    outfile.write(
+                        ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","ph":"i","s":"p","args":{"desc":"%s"}}\n'
+                        % (
+                            row[2],
+                            row[3],
+                            row[1].replace('"', ""),
+                            row[4],
+                            row[1].replace('"', ""),
+                        )
+                    )
+                else:
+                    outfile.write(
+                        ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                        % (
+                            row[2],
+                            row[3],
+                            row[1].replace('"', ""),
+                            row[4],
+                            row[5],
+                            row[1].replace('"', ""),
+                        )
+                    )
+            else:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                    % (
+                        row[2],
+                        row[3],
+                        row[0],
+                        row[4],
+                        row[5],
+                        row[1].replace('"', "").replace("\t", ""),
+                    )
+                )
+        except ValueError:
+            outfile.write("")
+
+    # Output api->op linkage
+    for row in connection.execute(
+        "select rocpd_api_ops.id, pid, tid, gpuId, queueId, rocpd_api.end/1000.0 - 2, rocpd_op.start/1000.0 from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id %s"
+        % (rangeStringApi)
+    ):
+        try:
+            fromtime = row[5] if row[5] < row[6] else row[6]
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"s"}\n'
+                % (row[1], row[2], fromtime, row[0])
+            )
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"f", "bp":"e"}\n'
+                % (row[3], row[4], row[6], row[0])
+            )
+        except ValueError:
+            outfile.write("")
+
+    try:
+        for row in connection.execute(
+            "select A.string as apiName, B.string as args, pid, tid, rocpd_hsaApi.start/1000.0, (rocpd_hsaApi.end-rocpd_hsaApi.start) / 1000.0 from rocpd_hsaApi INNER JOIN rocpd_string A on A.id = rocpd_hsaApi.apiName_id INNER Join rocpd_string B on B.id = rocpd_hsaApi.args_id %s order by rocpd_hsaApi.id"
+            % (rangeStringApi)
+        ):
+            try:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                    % (
+                        row[2],
+                        row[3] + 1,
+                        row[0],
+                        row[4],
+                        row[5],
+                        row[1].replace('"', ""),
+                    )
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    #
+    # Counters
+    #
+
+    # Counters should extend to the last event in the trace. This means they need to have a value at Tend.
+    # Figure out when that is
+
+    T_end = 0
+    for row in connection.execute(
+        "SELECT max(end)/1000 from (SELECT end from rocpd_api UNION ALL SELECT end from rocpd_op)"
+    ):
+        T_end = int(row[0])
+    if end:
+        T_end = end_time
+
+    # Loop over GPU for per-gpu counters
+    gpuIdsPresent = []
+    for row in connection.execute("SELECT DISTINCT gpuId FROM rocpd_op"):
+        gpuIdsPresent.append(row[0])
+
+    for gpuId in gpuIdsPresent:
+        # print(f"Creating counters for: {gpuId}")
+
+        # Create the queue depth counter
+        depth = 0
+        idle = 1
+        for row in connection.execute(
+            'select * from (select rocpd_api.start/1000.0 as ts, "1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s UNION ALL select rocpd_op.end/1000.0, "-1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s) order by ts'
+            % (gpuId, rangeStringOp, gpuId, rangeStringOp)
+        ):
+            try:
+                if idle and int(row[1]) > 0:
+                    idle = 0
+                    outfile.write(
+                        ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                        % (gpuId, row[0], idle)
+                    )
+                if depth == 1 and int(row[1]) < 0:
+                    idle = 1
+                    outfile.write(
+                        ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                        % (gpuId, row[0], idle)
+                    )
+                depth = depth + int(row[1])
+                outfile.write(
+                    ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
+                    % (gpuId, row[0], depth)
+                )
+            except ValueError:
+                outfile.write("")
+        if T_end > 0:
+            outfile.write(
+                ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                % (gpuId, T_end, idle)
+            )
+            outfile.write(
+                ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
+                % (gpuId, T_end, depth)
+            )
+
+    # Create SMI counters
+    try:
+        for row in connection.execute(
+            "select deviceId, monitorType, start/1000.0, value from rocpd_monitor %s"
+            % (rangeStringMonitor)
+        ):
+            outfile.write(
+                ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
+                % (row[0], row[1], row[2], row[1], row[3])
+            )
+        # Output the endpoints of the last range
+        for row in connection.execute(
+            "select distinct deviceId, monitorType, max(end)/1000.0, value from rocpd_monitor %s group by deviceId, monitorType"
+            % (rangeStringMonitor)
+        ):
+            outfile.write(
+                ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
+                % (row[0], row[1], row[2], row[1], row[3])
+            )
+    except:
+        print("Did not find SMI data")
+
+    # Create the (global) memory counter
+    """
+    sizes = {}  # address -> size
+    totalSize = 0
+    exp = re.compile("^ptr\((.*)\)\s+size\((.*)\)$")
+    exp2 = re.compile("^ptr\((.*)\)$")
+    for row in connection.execute("SELECT rocpd_api.end/1000.0 as ts, B.string, '1' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipFree' UNION ALL SELECT rocpd_api.start/1000.0, B.string, '0' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipMalloc' ORDER BY ts asc"):
+        try:
+            if row[2] == '0': #malloc
+                m = exp.match(row[1])
+                if m:
+                    size = int(m.group(2), 16)
+                    totalSize = totalSize + size
+                    sizes[m.group(1)] = size
+                    outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
+            else: #free
+                m = exp2.match(row[1])
+                if m:
+                    try: # Sometimes free addresses are not valid or listed
+                        size = sizes[m.group(1)]
+                        sizes[m.group(1)] = 0
+                        totalSize = totalSize - size;
+                        outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
+                    except KeyError:
+                        pass
+        except ValueError:
+            outfile.write("")
+    if T_end > 0:
+        outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(T_end,totalSize))
+    """
+
+    # Create "faux calling stack frame" on gpu ops traceS
+    stacks = {}  # Call stacks built from UserMarker entres. Key is 'pid,tid'
+    currentFrame = {}  # "Current GPU frame" (id, name, start, end). Key is 'pid,tid'
+
+    class GpuFrame:
+        def __init__(self):
+            self.id = 0
+            self.name = ""
+            self.start = 0
+            self.end = 0
+            self.gpus = []
+            self.totalOps = 0
+
+    # FIXME: include 'start' (in ns) so we can ORDER BY it and break ties?
+    for row in connection.execute(
+        "SELECT '0', start/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '1', end/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '2', rocpd_api.start/1000.0, pid, tid, '' as label, gpuId, queueId, rocpd_op.start/1000.0, rocpd_op.end/1000.0 from rocpd_api_ops INNER JOIN rocpd_api ON rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op ON rocpd_api_ops.op_id = rocpd_op.id %s ORDER BY start/1000.0 asc"
+        % (rangeStringApi, rangeStringApi, rangeStringApi)
+    ):
+        try:
+            key = (row[2], row[3])  # Key is 'pid,tid'
+            if row[0] == "0":  # Frame start
+                if key not in stacks:
+                    stacks[key] = []
+                stack = stacks[key].append((row[1], row[4]))
+                # print(f"0: new api frame: pid_tid={key} -> stack={stacks}")
+
+            elif row[0] == "1":  # Frame end
+                completed = stacks[key].pop()
+                # print(f"1: end api frame: pid_tid={key} -> stack={stacks}")
+
+            elif row[0] == "2":  # API + Op
+                if key in stacks and len(stacks[key]) > 0:
+                    frame = stacks[key][-1]
+                    # print(f"2: Op on {frame} ({len(stacks[key])})")
+                    gpuFrame = None
+                    if key not in currentFrame:  # First op under the current api frame
+                        gpuFrame = GpuFrame()
+                        gpuFrame.id = frame[0]
+                        gpuFrame.name = frame[1]
+                        gpuFrame.start = row[7]
+                        gpuFrame.end = row[8]
+                        gpuFrame.gpus.append((row[5], row[6]))
+                        gpuFrame.totalOps = 1
+                        # print(f"2a: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+                    else:
+                        gpuFrame = currentFrame[key]
+                        # Another op under the same frame -> union them (but only if they are butt together)
+                        if (
+                            gpuFrame.id == frame[0]
+                            and gpuFrame.name == frame[1]
+                            and (
+                                abs(row[7] - gpuFrame.end) < 200
+                                or abs(gpuFrame.start - row[8]) < 200
+                            )
+                        ):
+                            # if gpuFrame.id == frame[0] and gpuFrame.name == frame[1]:  # Another op under the same frame -> union them
+                            # if False:  # Turn off frame joining
+                            if row[7] < gpuFrame.start:
+                                gpuFrame.start = row[7]
+                            if row[8] > gpuFrame.end:
+                                gpuFrame.end = row[8]
+                            if (row[5], row[6]) not in gpuFrame.gpus:
+                                gpuFrame.gpus.append((row[5], row[6]))
+                            gpuFrame.totalOps = gpuFrame.totalOps + 1
+                            # print(f"2c: union frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+
+                        else:  # This is a new frame - dump the last and make new
+                            gpuFrame = currentFrame[key]
+                            for dest in gpuFrame.gpus:
+                                # print(f"2: OUTPUT: dest={dest} time={gpuFrame.start} -> {gpuFrame.end} Duration={gpuFrame.end - gpuFrame.start} TotalOps={gpuFrame.totalOps}")
+                                outfile.write(
+                                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                                    % (
+                                        dest[0],
+                                        dest[1],
+                                        gpuFrame.name.replace('"', ""),
+                                        gpuFrame.start - 1,
+                                        gpuFrame.end - gpuFrame.start + 1,
+                                        f"UserMarker frame: {gpuFrame.totalOps} ops",
+                                    )
+                                )
+                            currentFrame.pop(key)
+
+                            # make the first op under the new frame
+                            gpuFrame = GpuFrame()
+                            gpuFrame.id = frame[0]
+                            gpuFrame.name = frame[1]
+                            gpuFrame.start = row[7]
+                            gpuFrame.end = row[8]
+                            gpuFrame.gpus.append((row[5], row[6]))
+                            gpuFrame.totalOps = 1
+                            # print(f"2b: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+
+                    currentFrame[key] = gpuFrame
+
+        except ValueError:
+            outfile.write("")
+
+    outfile.write("]\n")
+
+    if format == "object":
+        outfile.write("} \n")
+
+    outfile.close()
+    connection.close()
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 2b32f6d73..e38400e3f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -752,6 +752,25 @@ def load_image(
     return image, image_size
 
 
+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
@@ -807,6 +826,33 @@
     os.unlink(tmp_file.name)
 
 
+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -949,6 +995,13 @@
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")
 
 
+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -3045,6 +3098,44 @@
     return results
 
 
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
 def numa_bind_to_node(node: int):
     libnuma = ctypes.CDLL("libnuma.so")
     if libnuma.numa_available() < 0:
@@ -3061,3 +3152,33 @@
     raise argparse.ArgumentTypeError(
         f"Invalid JSON list: {value}. Please provide a valid JSON list."
     )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    # and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    # so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    # in that case, each prefill contains chunked_prefill_size tokens,
+    # and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
diff --git a/python/sglang/srt/warmup.py b/python/sglang/srt/warmup.py
index 0bed9fb94..afba03006 100644
--- a/python/sglang/srt/warmup.py
+++ b/python/sglang/srt/warmup.py
@@ -1,20 +1,24 @@
+from __future__ import annotations
+
 import logging
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import tqdm
 
 from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
 
 logger = logging.getLogger(__file__)
 
 _warmup_registry = {}
 
 
-def warmup(name: str) -> callable:
-    def decorator(fn: callable):
+def warmup(name: str):
+    def decorator(fn):
         _warmup_registry[name] = fn
         return fn