Merge pull request #15943 from eatmoreapple/update-lora-load

feat: lora partial update precede full update
This commit is contained in:
AUTOMATIC1111 2024-06-08 08:51:42 +03:00 committed by GitHub
commit 6de733c91e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -260,6 +260,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.clear()
unavailable_networks = []
for name in names:
if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
unavailable_networks.append(name)
elif available_network_aliases.get(name) is None:
unavailable_networks.append(name)
if unavailable_networks:
update_available_networks_by_names(unavailable_networks)
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
@ -566,22 +576,16 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
def process_network_files(names: list[str] | None = None):
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in candidates:
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
# if names is provided, only load networks with names in the list
if names and name not in names:
continue
try:
entry = network.NetworkOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
@ -597,6 +601,22 @@ def list_available_networks():
available_network_aliases[entry.alias] = entry
def update_available_networks_by_names(names: list[str]):
process_network_files(names)
def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
process_network_files()
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")