Smarter memory management.

Try to keep models on the vram when possible.

Better lowvram mode for controlnets.
This commit is contained in:
comfyanonymous
2023-08-17 01:06:34 -04:00
parent 2c97c30256
commit 89a0767abf
6 changed files with 230 additions and 168 deletions

View File

@@ -244,30 +244,15 @@ class Gligen(nn.Module):
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
self.lowvram = False
self.current_device = torch.device("cpu")
def _set_position(self, boxes, masks, positive_embeddings):
if self.lowvram == True:
self.position_net.to(boxes.device)
objs = self.position_net(boxes, masks, positive_embeddings)
if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
module.cpu()
return r
return func_lowvram
else:
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return func
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
@@ -312,14 +297,6 @@ class Gligen(nn.Module):
masks.to(device),
conds.to(device))
def set_lowvram(self, value=True):
self.lowvram = value
def cleanup(self):
self.lowvram = False
def get_models(self):
return [self]
def load_gligen(sd):
sd_k = sd.keys()