mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 14:35:25 +08:00
Gradient clipping in train tab
This commit is contained in:
parent
737eb28fac
commit
2a25729623
@ -327,7 +327,7 @@ def report_statistics(loss_info:dict):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
|
||||||
@ -384,6 +384,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
clip_grad_mode_value = clip_grad_mode == "value"
|
||||||
|
clip_grad_mode_norm = clip_grad_mode == "norm"
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
@ -426,6 +429,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
steps_without_grad = 0
|
steps_without_grad = 0
|
||||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||||
|
|
||||||
|
if clip_grad_mode_value:
|
||||||
|
torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value)
|
||||||
|
elif clip_grad_mode_norm:
|
||||||
|
torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||||
|
@ -1313,6 +1313,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
with gr.Row():
|
||||||
|
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
||||||
|
clip_grad_value = gr.Number(value=1.0, show_label=False)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
@ -1406,6 +1409,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
clip_grad_mode,
|
||||||
|
clip_grad_value,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
@ -1431,6 +1436,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
clip_grad_mode,
|
||||||
|
clip_grad_value,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
|
Loading…
Reference in New Issue
Block a user