Support multiple paths for embeddings.
This commit is contained in:
@@ -168,19 +168,28 @@ def unescape_important(text):
|
||||
return text
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
embed_path = os.path.join(embedding_directory, embedding_name)
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
valid_file = None
|
||||
for x in extensions:
|
||||
t = embed_path + x
|
||||
if os.path.isfile(t):
|
||||
valid_file = t
|
||||
break
|
||||
if valid_file is None:
|
||||
return None
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
for x in extensions:
|
||||
t = embed_path + x
|
||||
if os.path.isfile(t):
|
||||
valid_file = t
|
||||
break
|
||||
else:
|
||||
embed_path = valid_file
|
||||
valid_file = embed_path
|
||||
if valid_file is not None:
|
||||
break
|
||||
|
||||
if valid_file is None:
|
||||
return None
|
||||
|
||||
embed_path = valid_file
|
||||
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
|
||||
Reference in New Issue
Block a user