Merge T2IAdapterLoader and ControlNetLoader.

Workflows will be auto updated.
This commit is contained in:
comfyanonymous
2023-03-17 18:17:59 -04:00
parent e1a9e26968
commit 2e73367f45
6 changed files with 14 additions and 44 deletions

View File

@@ -527,8 +527,10 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
pass
else:
print("error checkpoint does not contain controlnet data", ckpt_path)
return None
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
context_dim = controlnet_data[key].shape[1]
@@ -682,15 +684,16 @@ class T2IAdapter:
out += self.previous_controlnet.get_control_models()
return out
def load_t2i_adapter(ckpt_path, model=None):
t2i_data = load_torch_file(ckpt_path)
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
else:
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
else:
return None
model_ad.load_state_dict(t2i_data)
return T2IAdapter(model_ad, cin // 64)