From eb442494d11fd21d06135c4fa6bc5a696fecfe4d Mon Sep 17 00:00:00 2001 From: John Smith Date: Sat, 22 Apr 2023 16:35:18 +0800 Subject: [PATCH] optimize mem usage --- model_attn_mlp_patch.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/model_attn_mlp_patch.py b/model_attn_mlp_patch.py index 63e80cb..d2b0bc6 100644 --- a/model_attn_mlp_patch.py +++ b/model_attn_mlp_patch.py @@ -85,10 +85,27 @@ def make_quant_attn(model): v_proj = m.v_proj qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + del q_proj.qweight + del k_proj.qweight + del v_proj.qweight qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + del q_proj.qzeros + del k_proj.qzeros + del v_proj.qzeros scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + del q_proj.scales + del k_proj.scales + del v_proj.scales g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + del q_proj.g_idx + del k_proj.g_idx + del v_proj.g_idx bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + if q_proj.bias is not None: + del q_proj.bias + del k_proj.bias + del v_proj.bias + torch.cuda.empty_cache() qkv_layer = Autograd4bitQuantLinear(q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, @@ -124,10 +141,22 @@ class QuantLlamaMLP(nn.Module): up_proj = old_module.up_proj qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) + del gate_proj.qweight + del up_proj.qweight qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1) + del gate_proj.qzeros + del up_proj.qzeros scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1) + del gate_proj.scales + del up_proj.scales g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0) + del gate_proj.g_idx + del up_proj.g_idx bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None + if gate_proj.bias is not None: + del gate_proj.bias + del up_proj.bias + torch.cuda.empty_cache() self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features, gate_proj.out_features + up_proj.out_features,