From 3f6ccda35879f8a3738013a0fb0b7fcbe9516e69 Mon Sep 17 00:00:00 2001 From: singularity-s0 <12184989+singularity-s0@users.noreply.github.com> Date: Thu, 17 Aug 2023 13:39:30 +0800 Subject: [PATCH 1/4] support auto calculate total token num --- docs/ApiServerArgs.md | 3 ++- lightllm/server/api_server.py | 8 +++++- lightllm/utils/max_token_num_utils.py | 35 +++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 lightllm/utils/max_token_num_utils.py diff --git a/docs/ApiServerArgs.md b/docs/ApiServerArgs.md index 4e9171e60..c1fd81b50 100644 --- a/docs/ApiServerArgs.md +++ b/docs/ApiServerArgs.md @@ -22,7 +22,8 @@ tokenizer load mode, can be "slow" or "auto", "slow" mode always load fast but r #### --max_total_token_num -default is 6000, +default is automatically calculated. if you run into OOM error, try setting manually. + the total token num the gpu and model can support, a sample about how to set this arg: gpu: use 2 A100 80G, (--tp 2) model: llama-7b, diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 3cf95d2ec..79b8f18cd 100644 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -35,6 +35,7 @@ from .router.manager import start_router_process from lightllm.utils.net_utils import alloc_can_use_network_port +from lightllm.utils.max_token_num_utils import calc_max_total_token_num from lightllm.common.configs.config import setting TIMEOUT_KEEP_ALIVE = 5 # seconds. @@ -123,7 +124,7 @@ def main(): parser.add_argument("--tokenizer_mode", type=str, default="slow", help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is good for debug and test, when you want to get best performance, try auto mode""") - parser.add_argument("--max_total_token_num", type=int, default=6000, + parser.add_argument("--max_total_token_num", type=int, default=None, help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)") parser.add_argument("--batch_max_tokens", type=int, default=None, help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM") @@ -153,6 +154,11 @@ def main(): setting['max_req_total_len'] = args.max_req_total_len setting['nccl_port'] = args.nccl_port + if args.max_total_token_num is None: + max_total_token_num = calc_max_total_token_num(args.tp, args.model_dir) + print("Automatically setting max_total_token_num to", max_total_token_num) + args.max_total_token_num = max_total_token_num + if args.batch_max_tokens is None: batch_max_tokens = int(1 / 6 * args.max_total_token_num) batch_max_tokens = max(batch_max_tokens, args.max_req_total_len) diff --git a/lightllm/utils/max_token_num_utils.py b/lightllm/utils/max_token_num_utils.py new file mode 100644 index 000000000..1fa42e80b --- /dev/null +++ b/lightllm/utils/max_token_num_utils.py @@ -0,0 +1,35 @@ +import torch +import os + +def get_total_free_gpu_memory(tp): + """ + Returns the total amount of free memory available on all GPUs, in Gigabytes. + """ + devices = min(tp, torch.cuda.device_count()) + total_free = 0 + for i in range(devices): + total_free += torch.cuda.mem_get_info(i)[0] + total_free = total_free / (1024 ** 3) + return total_free + +def get_total_weight_size(weight_dir): + """ + Returns the total size of all parameters in the model, in Gigabytes. + """ + total_size = 0 + files = os.listdir(weight_dir) + candidate_files = list(filter(lambda x : x.endswith('.safetensors'), files)) + if len(candidate_files) == 0: + candidate_files = list(filter(lambda x : x.endswith('.bin'), files)) + assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." + for file in candidate_files: + total_size += os.path.getsize(os.path.join(weight_dir, file)) + total_size = total_size / (1024 ** 3) + return total_size + +def calc_max_total_token_num(tp, weight_dir, mem_fill_rate=0.8, kv_cache_size=0.000488281): + """ + Calculate the max total token num that can be supported by the model. + """ + max_token_num = (get_total_free_gpu_memory(tp)-get_total_weight_size(weight_dir)) * mem_fill_rate / kv_cache_size + return int(max_token_num) \ No newline at end of file From e68926ad7ba1d009a2e9bebf400014f9ac9f54dd Mon Sep 17 00:00:00 2001 From: singularity-s0 <12184989+singularity-s0@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:12:35 +0800 Subject: [PATCH 2/4] fix cuda multiprocessing error --- lightllm/utils/max_token_num_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lightllm/utils/max_token_num_utils.py b/lightllm/utils/max_token_num_utils.py index 1fa42e80b..6842d7b78 100644 --- a/lightllm/utils/max_token_num_utils.py +++ b/lightllm/utils/max_token_num_utils.py @@ -1,4 +1,5 @@ import torch +torch.multiprocessing.set_start_method('spawn', force=True) # Fork start method will cause CUDA re-initialization error import os def get_total_free_gpu_memory(tp): @@ -10,6 +11,7 @@ def get_total_free_gpu_memory(tp): for i in range(devices): total_free += torch.cuda.mem_get_info(i)[0] total_free = total_free / (1024 ** 3) + torch.cuda.close() return total_free def get_total_weight_size(weight_dir): From a9359d958297ddaea87ce5d99a0b371e055b4f1c Mon Sep 17 00:00:00 2001 From: singularity-s0 <12184989+singularity-s0@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:13:52 +0800 Subject: [PATCH 3/4] bug fix --- lightllm/utils/max_token_num_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/utils/max_token_num_utils.py b/lightllm/utils/max_token_num_utils.py index 6842d7b78..4f39c7295 100644 --- a/lightllm/utils/max_token_num_utils.py +++ b/lightllm/utils/max_token_num_utils.py @@ -11,7 +11,6 @@ def get_total_free_gpu_memory(tp): for i in range(devices): total_free += torch.cuda.mem_get_info(i)[0] total_free = total_free / (1024 ** 3) - torch.cuda.close() return total_free def get_total_weight_size(weight_dir): From a1ef20968d79b45a33e08fba609c64979f95885c Mon Sep 17 00:00:00 2001 From: singularity-s0 <12184989+singularity-s0@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:44:05 +0800 Subject: [PATCH 4/4] update kv_cache_size calculation --- lightllm/utils/max_token_num_utils.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/lightllm/utils/max_token_num_utils.py b/lightllm/utils/max_token_num_utils.py index 4f39c7295..5af90521c 100644 --- a/lightllm/utils/max_token_num_utils.py +++ b/lightllm/utils/max_token_num_utils.py @@ -1,6 +1,7 @@ import torch torch.multiprocessing.set_start_method('spawn', force=True) # Fork start method will cause CUDA re-initialization error import os +import json def get_total_free_gpu_memory(tp): """ @@ -28,9 +29,31 @@ def get_total_weight_size(weight_dir): total_size = total_size / (1024 ** 3) return total_size -def calc_max_total_token_num(tp, weight_dir, mem_fill_rate=0.8, kv_cache_size=0.000488281): +def get_kv_cache_size(model_dir): + """ + Returns the size of the kv cache for a single token, in Gigabytes. + """ + # Read from config.json + config_path = os.path.join(model_dir, 'config.json') + assert os.path.exists(config_path), "config.json not found in model directory." + try: + with open(config_path, 'r') as f: + config = json.load(f) + hidden_size = config['hidden_size'] + layer_num = config['num_hidden_layers'] + num_attention_heads = config['num_attention_heads'] + num_key_value_heads = config.get('num_key_value_heads', num_attention_heads) # Models may not be using GQA + dtype = config.get('torch_dtype', 'float16') # TODO: dtype may not be specified in config.json, should we load weights to check? + except: + raise Exception("Error reading config.json when trying to determine max_total_token_num. Please manually specify max_total_token_num in startup arguments.") + dtype_size = torch.empty(0, dtype=getattr(torch, dtype)).element_size() + kv_cache_size = hidden_size * dtype_size * 2 * layer_num / num_attention_heads * num_key_value_heads / (1024 ** 3) + return kv_cache_size + +def calc_max_total_token_num(tp, weight_dir, mem_fill_rate=0.8): """ Calculate the max total token num that can be supported by the model. """ + kv_cache_size = get_kv_cache_size(weight_dir) max_token_num = (get_total_free_gpu_memory(tp)-get_total_weight_size(weight_dir)) * mem_fill_rate / kv_cache_size return int(max_token_num) \ No newline at end of file