Support multiple paths for embeddings.

This commit is contained in:
comfyanonymous
2023-03-18 03:08:43 -04:00
parent 51d6427ddf
commit 50099bcd96
3 changed files with 26 additions and 18 deletions

View File

@@ -168,19 +168,28 @@ def unescape_important(text):
return text
def load_embed(embedding_name, embedding_directory):
embed_path = os.path.join(embedding_directory, embedding_name)
if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin']
valid_file = None
for x in extensions:
t = embed_path + x
if os.path.isfile(t):
valid_file = t
break
if valid_file is None:
return None
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin']
for x in extensions:
t = embed_path + x
if os.path.isfile(t):
valid_file = t
break
else:
embed_path = valid_file
valid_file = embed_path
if valid_file is not None:
break
if valid_file is None:
return None
embed_path = valid_file
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch