mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-18 14:55:09 +08:00
Merge branch 'master' into extra-networks-toggle
This commit is contained in:
commit
6f18c9b13f
2
.github/workflows/run_tests.yaml
vendored
2
.github/workflows/run_tests.yaml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
**/requirements*txt
|
||||
- name: Run tests
|
||||
run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||
- name: Upload main app stdout-stderr
|
||||
uses: actions/upload-artifact@v3
|
||||
if: always()
|
||||
|
@ -3,7 +3,9 @@ import os
|
||||
import re
|
||||
import torch
|
||||
|
||||
from modules import shared, devices, sd_models
|
||||
from modules import shared, devices, sd_models, errors
|
||||
|
||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
||||
@ -43,6 +45,23 @@ class LoraOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
self.name = name
|
||||
self.filename = filename
|
||||
self.metadata = {}
|
||||
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext.lower() == ".safetensors":
|
||||
try:
|
||||
self.metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading lora {filename}")
|
||||
|
||||
if self.metadata:
|
||||
m = {}
|
||||
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
||||
m[k] = v
|
||||
|
||||
self.metadata = m
|
||||
|
||||
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||
|
||||
|
||||
class LoraModule:
|
||||
@ -159,6 +178,7 @@ def load_loras(names, multipliers=None):
|
||||
|
||||
|
||||
def lora_forward(module, input, res):
|
||||
input = devices.cond_cast_unet(input)
|
||||
if len(loaded_loras) == 0:
|
||||
return res
|
||||
|
||||
|
@ -23,6 +23,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
|
@ -89,22 +89,15 @@ function checkBrackets(evt, textArea, counterElt) {
|
||||
function setupBracketChecking(id_prompt, id_counter){
|
||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
|
||||
textarea.addEventListener("input", function(evt){
|
||||
checkBrackets(evt, textarea, counter)
|
||||
});
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
|
||||
if(! shadowRoot) return false;
|
||||
|
||||
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) return false;
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
onUiLoaded(function(){
|
||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
||||
setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
|
||||
setupBracketChecking('img2img_prompt', 'img2img_token_counter')
|
||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
||||
}, 1000);
|
||||
})
|
@ -1,4 +1,6 @@
|
||||
<div class='card' {preview_html} onclick={card_clicked}>
|
||||
<div class='card' style={style} onclick={card_clicked}>
|
||||
{metadata_button}
|
||||
|
||||
<div class='actions'>
|
||||
<div class='additional'>
|
||||
<ul>
|
||||
|
@ -635,4 +635,30 @@ SOFTWARE.
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/explosion/curated-transformers/blob/main/LICENSE">Curated transformers</a></h2>
|
||||
<small>The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers</small>
|
||||
<pre>
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (C) 2021 ExplosionAI GmbH
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
</pre>
|
@ -43,7 +43,7 @@ contextMenuInit = function(){
|
||||
|
||||
})
|
||||
|
||||
gradioApp().getRootNode().appendChild(contextMenu)
|
||||
gradioApp().appendChild(contextMenu)
|
||||
|
||||
let menuWidth = contextMenu.offsetWidth + 4;
|
||||
let menuHeight = contextMenu.offsetHeight + 4;
|
||||
|
@ -1,6 +1,6 @@
|
||||
function keyupEditAttention(event){
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
|
||||
if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
let isPlus = event.key == "ArrowUp"
|
||||
|
@ -102,4 +102,78 @@ function extraNetworksSearchButton(tabs_id, event){
|
||||
|
||||
searchTextarea.value = text
|
||||
updateInput(searchTextarea)
|
||||
}
|
||||
}
|
||||
|
||||
var globalPopup = null;
|
||||
var globalPopupInner = null;
|
||||
function popup(contents){
|
||||
if(! globalPopup){
|
||||
globalPopup = document.createElement('div')
|
||||
globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
|
||||
globalPopup.classList.add('global-popup');
|
||||
|
||||
var close = document.createElement('div')
|
||||
close.classList.add('global-popup-close');
|
||||
close.onclick = function(){ globalPopup.style.display = "none"; };
|
||||
close.title = "Close";
|
||||
globalPopup.appendChild(close)
|
||||
|
||||
globalPopupInner = document.createElement('div')
|
||||
globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
|
||||
globalPopupInner.classList.add('global-popup-inner');
|
||||
globalPopup.appendChild(globalPopupInner)
|
||||
|
||||
gradioApp().appendChild(globalPopup);
|
||||
}
|
||||
|
||||
globalPopupInner.innerHTML = '';
|
||||
globalPopupInner.appendChild(contents);
|
||||
|
||||
globalPopup.style.display = "flex";
|
||||
}
|
||||
|
||||
function extraNetworksShowMetadata(text){
|
||||
elem = document.createElement('pre')
|
||||
elem.classList.add('popup-metadata');
|
||||
elem.textContent = text;
|
||||
|
||||
popup(elem);
|
||||
}
|
||||
|
||||
function requestGet(url, data, handler, errorHandler){
|
||||
var xhr = new XMLHttpRequest();
|
||||
var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
|
||||
xhr.open("GET", url + "?" + args, true);
|
||||
|
||||
xhr.onreadystatechange = function () {
|
||||
if (xhr.readyState === 4) {
|
||||
if (xhr.status === 200) {
|
||||
try {
|
||||
var js = JSON.parse(xhr.responseText);
|
||||
handler(js)
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
errorHandler()
|
||||
}
|
||||
} else{
|
||||
errorHandler()
|
||||
}
|
||||
}
|
||||
};
|
||||
var js = JSON.stringify(data);
|
||||
xhr.send(js);
|
||||
}
|
||||
|
||||
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
||||
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
||||
|
||||
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
||||
if(data && data.metadata){
|
||||
extraNetworksShowMetadata(data.metadata)
|
||||
} else{
|
||||
showError()
|
||||
}
|
||||
}, showError)
|
||||
|
||||
event.stopPropagation()
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ titles = {
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\u{1f5d1}": "Clear prompt",
|
||||
"\u{1f5d1}\ufe0f": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show/hide extra networks",
|
||||
@ -39,8 +39,7 @@ titles = {
|
||||
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
||||
|
||||
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
||||
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
|
||||
|
||||
|
||||
"Skip": "Stop processing current image and continue processing.",
|
||||
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
||||
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
||||
@ -70,8 +69,10 @@ titles = {
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
"Loops": "How many times to repeat processing an image and using it as input for the next iteration",
|
||||
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
||||
"Loops": "How many times to process an image. Each output is used as the input of the next loop. If set to 1, behavior will be as if this script were not used.",
|
||||
"Final denoising strength": "The denoising strength for the final loop of each image in the batch.",
|
||||
"Denoising strength curve": "The denoising curve controls the rate of denoising strength change each loop. Aggressive: Most of the change will happen towards the start of the loops. Linear: Change will be constant through all loops. Lazy: Most of the change will happen towards the end of the loops.",
|
||||
|
||||
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||
|
@ -50,7 +50,7 @@ function updateOnBackgroundChange() {
|
||||
}
|
||||
|
||||
function modalImageSwitch(offset) {
|
||||
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
|
||||
var allgalleryButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item")
|
||||
var galleryButtons = []
|
||||
allgalleryButtons.forEach(function(elem) {
|
||||
if (elem.parentElement.offsetParent) {
|
||||
@ -59,7 +59,7 @@ function modalImageSwitch(offset) {
|
||||
})
|
||||
|
||||
if (galleryButtons.length > 1) {
|
||||
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
||||
var allcurrentButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item.selected")
|
||||
var currentButton = null
|
||||
allcurrentButtons.forEach(function(elem) {
|
||||
if (elem.parentElement.offsetParent) {
|
||||
@ -136,37 +136,29 @@ function modalKeyHandler(event) {
|
||||
}
|
||||
}
|
||||
|
||||
function showGalleryImage() {
|
||||
setTimeout(function() {
|
||||
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
|
||||
function setupImageForLightbox(e) {
|
||||
if (e.dataset.modded)
|
||||
return;
|
||||
|
||||
if (fullImg_preview != null) {
|
||||
fullImg_preview.forEach(function function_name(e) {
|
||||
if (e.dataset.modded)
|
||||
return;
|
||||
e.dataset.modded = true;
|
||||
if(e && e.parentElement.tagName == 'DIV'){
|
||||
e.style.cursor='pointer'
|
||||
e.style.userSelect='none'
|
||||
e.dataset.modded = true;
|
||||
e.style.cursor='pointer'
|
||||
e.style.userSelect='none'
|
||||
|
||||
var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
|
||||
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
|
||||
|
||||
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
||||
// If you know how to fix this without switching to mousedown event, please.
|
||||
// For other browsers the event is click to make it possiblr to drag picture.
|
||||
var event = isFirefox ? 'mousedown' : 'click'
|
||||
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
||||
// If you know how to fix this without switching to mousedown event, please.
|
||||
// For other browsers the event is click to make it possiblr to drag picture.
|
||||
var event = isFirefox ? 'mousedown' : 'click'
|
||||
|
||||
e.addEventListener(event, function (evt) {
|
||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||
evt.preventDefault()
|
||||
showModal(evt)
|
||||
}, true);
|
||||
}
|
||||
});
|
||||
}
|
||||
e.addEventListener(event, function (evt) {
|
||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||
evt.preventDefault()
|
||||
showModal(evt)
|
||||
}, true);
|
||||
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function modalZoomSet(modalImage, enable) {
|
||||
@ -199,21 +191,21 @@ function modalTileImageToggle(event) {
|
||||
}
|
||||
|
||||
function galleryImageHandler(e) {
|
||||
if (e && e.parentElement.tagName == 'BUTTON') {
|
||||
//if (e && e.parentElement.tagName == 'BUTTON') {
|
||||
e.onclick = showGalleryImage;
|
||||
}
|
||||
//}
|
||||
}
|
||||
|
||||
onUiUpdate(function() {
|
||||
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
|
||||
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
||||
if (fullImg_preview != null) {
|
||||
fullImg_preview.forEach(galleryImageHandler);
|
||||
fullImg_preview.forEach(setupImageForLightbox);
|
||||
}
|
||||
updateOnBackgroundChange();
|
||||
})
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
const modalFragment = document.createDocumentFragment();
|
||||
//const modalFragment = document.createDocumentFragment();
|
||||
const modal = document.createElement('div')
|
||||
modal.onclick = closeModal;
|
||||
modal.id = "lightboxModal";
|
||||
@ -277,9 +269,9 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||
|
||||
modal.appendChild(modalNext)
|
||||
|
||||
gradioApp().appendChild(modal)
|
||||
|
||||
gradioApp().getRootNode().appendChild(modal)
|
||||
|
||||
document.body.appendChild(modalFragment);
|
||||
document.body.appendChild(modal);
|
||||
|
||||
});
|
||||
|
@ -15,7 +15,7 @@ onUiUpdate(function(){
|
||||
}
|
||||
}
|
||||
|
||||
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
|
||||
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
|
||||
|
||||
if (galleryPreviews == null) return;
|
||||
|
||||
|
@ -1,78 +1,13 @@
|
||||
// code related to showing and updating progressbar shown as the image is being made
|
||||
|
||||
|
||||
galleries = {}
|
||||
storedGallerySelections = {}
|
||||
galleryObservers = {}
|
||||
|
||||
function rememberGallerySelection(id_gallery){
|
||||
storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
|
||||
|
||||
}
|
||||
|
||||
function getGallerySelectedIndex(id_gallery){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
|
||||
let currentlySelectedIndex = -1
|
||||
galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
|
||||
|
||||
return currentlySelectedIndex
|
||||
}
|
||||
|
||||
// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
|
||||
function check_gallery(id_gallery){
|
||||
let gallery = gradioApp().getElementById(id_gallery)
|
||||
// if gallery has no change, no need to setting up observer again.
|
||||
if (gallery && galleries[id_gallery] !== gallery){
|
||||
galleries[id_gallery] = gallery;
|
||||
if(galleryObservers[id_gallery]){
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
}
|
||||
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
galleryObservers[id_gallery] = new MutationObserver(function (){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
|
||||
prevSelectedIndex = storedGallerySelections[id_gallery]
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||
// automatically re-open previously selected index (if exists)
|
||||
activeElement = gradioApp().activeElement;
|
||||
let scrollX = window.scrollX;
|
||||
let scrollY = window.scrollY;
|
||||
|
||||
galleryButtons[prevSelectedIndex].click();
|
||||
showGalleryImage();
|
||||
|
||||
// When the gallery button is clicked, it gains focus and scrolls itself into view
|
||||
// We need to scroll back to the previous position
|
||||
setTimeout(function (){
|
||||
window.scrollTo(scrollX, scrollY);
|
||||
}, 50);
|
||||
|
||||
if(activeElement){
|
||||
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
|
||||
// if someone has a better solution please by all means
|
||||
setTimeout(function (){
|
||||
activeElement.focus({
|
||||
preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
|
||||
})
|
||||
}, 1);
|
||||
}
|
||||
}
|
||||
})
|
||||
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
||||
}
|
||||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
check_gallery('txt2img_gallery')
|
||||
check_gallery('img2img_gallery')
|
||||
})
|
||||
|
||||
function request(url, data, handler, errorHandler){
|
||||
var xhr = new XMLHttpRequest();
|
||||
var url = url;
|
||||
|
@ -86,7 +86,7 @@ function get_tab_index(tabId){
|
||||
var res = 0
|
||||
|
||||
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
|
||||
if(button.className.indexOf('bg-white') != -1)
|
||||
if(button.className.indexOf('selected') != -1)
|
||||
res = i
|
||||
})
|
||||
|
||||
@ -255,7 +255,6 @@ onUiUpdate(function(){
|
||||
}
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
|
||||
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
|
||||
|
79
launch.py
79
launch.py
@ -5,24 +5,25 @@ import sys
|
||||
import importlib.util
|
||||
import shlex
|
||||
import platform
|
||||
import argparse
|
||||
import json
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, default='config.json')
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
|
||||
args, _ = parser.parse_known_args(sys.argv)
|
||||
from modules import cmd_args
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
|
||||
script_path = os.path.dirname(__file__)
|
||||
data_path = os.getcwd()
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
|
||||
dir_repos = "repositories"
|
||||
dir_extensions = "extensions"
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
stored_commit_hash = None
|
||||
skip_install = False
|
||||
dir_repos = "repositories"
|
||||
|
||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
|
||||
|
||||
def check_python_version():
|
||||
@ -70,23 +71,6 @@ def commit_hash():
|
||||
return stored_commit_hash
|
||||
|
||||
|
||||
def extract_arg(args, name):
|
||||
return [x for x in args if x != name], name in args
|
||||
|
||||
|
||||
def extract_opt(args, name):
|
||||
opt = None
|
||||
is_present = False
|
||||
if name in args:
|
||||
is_present = True
|
||||
idx = args.index(name)
|
||||
del args[idx]
|
||||
if idx < len(args) and args[idx][0] != "-":
|
||||
opt = args[idx]
|
||||
del args[idx]
|
||||
return args, is_present, opt
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
@ -223,15 +207,15 @@ def list_extensions(settings_file):
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
|
||||
return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(dir_extensions):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
run_extension_installer(os.path.join(dir_extensions, dirname_extension))
|
||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
@ -239,7 +223,6 @@ def prepare_environment():
|
||||
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||
@ -258,21 +241,7 @@ def prepare_environment():
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
||||
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
|
||||
xformers = '--xformers' in sys.argv
|
||||
ngrok = '--ngrok' in sys.argv
|
||||
|
||||
if not skip_python_version_check:
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
commit = commit_hash()
|
||||
@ -280,10 +249,10 @@ def prepare_environment():
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
|
||||
if not skip_torch_cuda_test:
|
||||
if not args.skip_torch_cuda_test:
|
||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
||||
|
||||
if not is_installed("gfpgan"):
|
||||
@ -295,7 +264,7 @@ def prepare_environment():
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
|
||||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
@ -307,7 +276,7 @@ def prepare_environment():
|
||||
elif platform.system() == "Linux":
|
||||
run_pip(f"install {xformers_package}", "xformers")
|
||||
|
||||
if not is_installed("pyngrok") and ngrok:
|
||||
if not is_installed("pyngrok") and args.ngrok:
|
||||
run_pip("install pyngrok", "ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
@ -327,18 +296,18 @@ def prepare_environment():
|
||||
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if update_check:
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
|
||||
if update_all_extensions:
|
||||
git_pull_recursive(os.path.join(data_path, dir_extensions))
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
if run_tests:
|
||||
exitcode = tests(test_dir)
|
||||
if args.tests and not args.no_tests:
|
||||
exitcode = tests(args.tests)
|
||||
exit(exitcode)
|
||||
|
||||
|
||||
@ -352,6 +321,8 @@ def tests(test_dir):
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
if "--no-tests" not in sys.argv:
|
||||
sys.argv.append("--no-tests")
|
||||
|
||||
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
|
@ -6,8 +6,11 @@ import uvicorn
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
try:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
except:
|
||||
import traceback
|
||||
rich_available = False
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
|
||||
))
|
||||
return res
|
||||
|
||||
def handle_exception(request: Request, e: Exception):
|
||||
err = {
|
||||
"error": type(e).__name__,
|
||||
"detail": vars(e).get('detail', ''),
|
||||
"body": vars(e).get('body', ''),
|
||||
"errors": str(e),
|
||||
}
|
||||
print(f"API error: {request.method}: {request.url} {err}")
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
if rich_available:
|
||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||
else:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||
|
||||
@app.middleware("http")
|
||||
async def exception_handling(request: Request, call_next):
|
||||
try:
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def fastapi_exception_handler(request: Request, e: Exception):
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, e: HTTPException):
|
||||
return handle_exception(request, e)
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
@ -150,6 +193,8 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
@ -412,6 +457,16 @@ class Api:
|
||||
|
||||
return {}
|
||||
|
||||
def unloadapi(self):
|
||||
unload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def reloadapi(self):
|
||||
reload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
|
102
modules/cmd_args.py
Normal file
102
modules/cmd_args.py
Normal file
@ -0,0 +1,102 @@
|
||||
import argparse
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
||||
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
|
||||
parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
|
||||
parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
@ -6,13 +6,12 @@ import time
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
if not os.path.exists(paths.extensions_dir):
|
||||
os.makedirs(paths.extensions_dir)
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
@ -86,11 +85,11 @@ class Extension:
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
if not os.path.isdir(paths.extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
extension_paths = []
|
||||
for dirname in [paths.extensions_dir, paths.extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
|
||||
@ -99,9 +98,9 @@ def list_extensions():
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
extension_paths.append((extension_dirname, path, dirname == paths.extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
for dirname, path, is_builtin in extension_paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
|
@ -401,9 +401,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
os.replace(temp_file_path, filename_without_extension + extension)
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
if hasattr(os, 'statvfs'):
|
||||
max_name_len = os.statvfs(path).f_namemax
|
||||
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
||||
params.filename = fullfn_without_extension + extension
|
||||
fullfn = params.filename
|
||||
_atomically_save_image(image, fullfn_without_extension, extension)
|
||||
|
||||
image.already_saved_as = fullfn
|
||||
@ -640,6 +645,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
|
||||
|
||||
|
||||
def image_data(data):
|
||||
import gradio as gr
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
@ -655,7 +662,7 @@ def image_data(data):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return '', None
|
||||
return gr.update(), None
|
||||
|
||||
|
||||
def flatten(img, bgcolor):
|
||||
|
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import platform
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
@ -32,6 +33,10 @@ if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if platform.mac_ver()[0].startswith("13.2."):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
@ -49,4 +54,6 @@ if has_mps:
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
||||
|
@ -4,7 +4,6 @@ import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
|
@ -71,7 +71,7 @@ class UniPCSampler(object):
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for UniPC sampling is {size}')
|
||||
# print(f'Data shape for UniPC sampling is {size}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
@ -750,7 +751,7 @@ class UniPC:
|
||||
if method == 'multistep':
|
||||
assert steps >= order, "UniPC order must be < sampling steps"
|
||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
||||
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
with torch.no_grad():
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
@ -766,7 +767,7 @@ class UniPC:
|
||||
self.after_update(x, model_x)
|
||||
model_prev_list.append(model_x)
|
||||
t_prev_list.append(vec_t)
|
||||
for step in range(order, steps + 1):
|
||||
for step in trange(order, steps + 1):
|
||||
vec_t = timesteps[step].expand(x.shape[0])
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
|
@ -1,16 +1,9 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
||||
|
||||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
22
modules/paths_internal.py
Normal file
22
modules/paths_internal.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
|
||||
models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
extra_network_data = None
|
||||
for n in range(p.n_iter):
|
||||
p.iteration = n
|
||||
|
||||
@ -688,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
||||
|
||||
if opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||
|
||||
if opts.save_mask_composite:
|
||||
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
|
||||
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
if opts.return_mask_composite:
|
||||
output_images.append(image_mask_composite)
|
||||
|
||||
del x_samples_ddim
|
||||
|
||||
devices.torch_gc()
|
||||
@ -712,7 +729,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
if not p.disable_extra_networks and extra_network_data:
|
||||
extra_networks.deactivate(p, extra_network_data)
|
||||
|
||||
devices.torch_gc()
|
||||
|
@ -239,7 +239,15 @@ def load_scripts():
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
def orderby(basedir):
|
||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
||||
for key in priority:
|
||||
if basedir.startswith(key):
|
||||
return priority[key]
|
||||
return 9999
|
||||
|
||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
@ -513,6 +521,18 @@ def reload_scripts():
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
|
@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
|
||||
inputs = []
|
||||
|
||||
for script in self.scripts_in_preferred_order():
|
||||
with gr.Box() as group:
|
||||
with gr.Row() as group:
|
||||
self.create_script_ui(script, inputs)
|
||||
|
||||
script.group = group
|
||||
|
@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
|
@ -67,7 +67,7 @@ def hijack_ddpm_edit():
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
@ -178,7 +178,7 @@ def select_checkpoint():
|
||||
return checkpoint_info
|
||||
|
||||
|
||||
chckpoint_dict_replacements = {
|
||||
checkpoint_dict_replacements = {
|
||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||
@ -186,7 +186,7 @@ chckpoint_dict_replacements = {
|
||||
|
||||
|
||||
def transform_checkpoint_dict_key(k):
|
||||
for text, replacement in chckpoint_dict_replacements.items():
|
||||
for text, replacement in checkpoint_dict_replacements.items():
|
||||
if k.startswith(text):
|
||||
k = replacement + k[len(text):]
|
||||
|
||||
@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||
return pl_sd
|
||||
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
import json
|
||||
|
||||
with open(filename, mode="rb") as file:
|
||||
metadata_len = file.read(8)
|
||||
metadata_len = int.from_bytes(metadata_len, "little")
|
||||
json_start = file.read(2)
|
||||
|
||||
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
|
||||
json_data = json_start + file.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
|
||||
res = {}
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
@ -470,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
@ -493,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if shared.sd_model:
|
||||
|
||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
shared.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
return sd_model
|
@ -13,114 +13,22 @@ import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
|
||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
||||
|
||||
demo = None
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = cmd_args.parser
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
script_loading.preload_extensions(extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||
|
||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
cmd_opts = parser.parse_args()
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
@ -332,6 +240,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||
@ -448,12 +358,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
|
@ -152,7 +152,11 @@ class EmbeddingDatabase:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
if data:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
# if data is None, means this is not an embeding, just a preview image
|
||||
return
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
|
@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
@ -89,7 +89,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
switch_values_symbol = '\U000021C5' # ⇅
|
||||
|
||||
@ -179,14 +179,13 @@ def interrogate_deepbooru(image):
|
||||
|
||||
|
||||
def create_seed_inputs(target_interface):
|
||||
with FormRow(elem_id=target_interface + '_seed_row'):
|
||||
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
||||
seed.style(container=False)
|
||||
random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
|
||||
with gr.Group(elem_id=target_interface + '_subseed_show_box'):
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
|
||||
# Components to show/hide based on the 'Extra' checkbox
|
||||
seed_extras = []
|
||||
@ -195,8 +194,8 @@ def create_seed_inputs(target_interface):
|
||||
seed_extras.append(seed_extra_row_1)
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
|
||||
subseed.style(container=False)
|
||||
random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
|
||||
|
||||
with FormRow(visible=False) as seed_extra_row_2:
|
||||
@ -291,19 +290,19 @@ def create_toprow(is_img2img):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
skip.click(
|
||||
@ -325,9 +324,9 @@ def create_toprow(is_img2img):
|
||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||
|
||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
||||
negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
|
||||
clear_prompt_button.click(
|
||||
@ -479,7 +478,9 @@ def create_ui():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
@ -492,7 +493,7 @@ def create_ui():
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
|
||||
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
||||
@ -586,7 +587,7 @@ def create_ui():
|
||||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
@ -757,7 +758,9 @@ def create_ui():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
@ -774,7 +777,7 @@ def create_ui():
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
|
||||
|
||||
elif category == "checkboxes":
|
||||
with FormRow(elem_id="img2img_checkboxes", variant="compact"):
|
||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
|
||||
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
|
||||
|
||||
@ -904,7 +907,7 @@ def create_ui():
|
||||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
@ -1491,11 +1494,33 @@ def create_ui():
|
||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
with gr.Row():
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||
|
||||
with gr.TabItem("Licenses"):
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
|
||||
|
||||
def unload_sd_weights():
|
||||
modules.sd_models.unload_model_weights()
|
||||
|
||||
def reload_sd_weights():
|
||||
modules.sd_models.reload_model_weights()
|
||||
|
||||
unload_sd_model.click(
|
||||
fn=unload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
reload_sd_model.click(
|
||||
fn=reload_sd_weights,
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
request_notifications.click(
|
||||
fn=lambda: None,
|
||||
@ -1598,11 +1623,13 @@ def create_ui():
|
||||
|
||||
for i, k, item in quicksettings_list:
|
||||
component = component_dict[k]
|
||||
info = opts.data_labels[k]
|
||||
|
||||
component.change(
|
||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
|
@ -129,8 +129,8 @@ Requested path was: {f}
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
|
||||
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
@ -149,7 +149,7 @@ Requested path was: {f}
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
@ -160,6 +160,7 @@ Requested path was: {f}
|
||||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||
inputs=[generation_info, html_info, html_info],
|
||||
outputs=[html_info, html_info],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
save.click(
|
||||
@ -195,7 +196,7 @@ Requested path was: {f}
|
||||
|
||||
else:
|
||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
paste_field_names = []
|
||||
|
@ -1,55 +1,61 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
class FormComponent:
|
||||
def get_expected_parent(self):
|
||||
return gr.components.Form
|
||||
|
||||
|
||||
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
|
||||
|
||||
|
||||
class ToolButton(FormComponent, gr.Button):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool", **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
classes = kwargs.pop("elem_classes", [])
|
||||
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool-top", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
class FormRow(FormComponent, gr.Row):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "row"
|
||||
|
||||
|
||||
class FormGroup(gr.Group, gr.components.FormComponent):
|
||||
class FormColumn(FormComponent, gr.Column):
|
||||
"""Same as gr.Column but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "column"
|
||||
|
||||
|
||||
class FormGroup(FormComponent, gr.Group):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "group"
|
||||
|
||||
|
||||
class FormHTML(gr.HTML, gr.components.FormComponent):
|
||||
class FormHTML(FormComponent, gr.HTML):
|
||||
"""Same as gr.HTML but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "html"
|
||||
|
||||
|
||||
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
||||
class FormColorPicker(FormComponent, gr.ColorPicker):
|
||||
"""Same as gr.ColorPicker but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
class DropdownMulti(FormComponent, gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url):
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
try:
|
||||
os.rename(tmpdir, target_dir)
|
||||
except OSError as err:
|
||||
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
|
||||
# Shouldn't cause any new issues at least but we probably want to handle it there too.
|
||||
if err.errno == errno.EXDEV:
|
||||
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
||||
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
||||
shutil.move(tmpdir, target_dir)
|
||||
else:
|
||||
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
||||
raise(err)
|
||||
raise err
|
||||
|
||||
import launch
|
||||
launch.run_extension_installer(target_dir)
|
||||
@ -167,12 +164,12 @@ def install_extension_from_url(dirname, url):
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url, hide_tags, sort_column):
|
||||
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ext_table, message
|
||||
return code, ext_table, message, ''
|
||||
|
||||
|
||||
def refresh_available_extensions(url, hide_tags, sort_column):
|
||||
@ -186,11 +183,17 @@ def refresh_available_extensions(url, hide_tags, sort_column):
|
||||
|
||||
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), ''
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
|
||||
|
||||
|
||||
def refresh_available_extensions_for_tags(hide_tags, sort_column):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
|
||||
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ''
|
||||
|
||||
|
||||
def search_extensions(filter_text, hide_tags, sort_column):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
||||
|
||||
return code, ''
|
||||
|
||||
@ -205,7 +208,7 @@ sort_ordering = [
|
||||
]
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data(hide_tags, sort_column):
|
||||
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
@ -244,7 +247,12 @@ def refresh_available_extensions_from_data(hide_tags, sort_column):
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
if filter_text and filter_text.strip():
|
||||
if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
|
||||
|
||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||
|
||||
@ -312,30 +320,39 @@ def create_ui():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
|
||||
|
||||
with gr.Row():
|
||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index, hide_tags, sort_column],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install, hide_tags, sort_column],
|
||||
inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
search_extensions_text.change(
|
||||
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
|
||||
inputs=[search_extensions_text, hide_tags, sort_column],
|
||||
outputs=[available_extensions_table, install_result],
|
||||
)
|
||||
|
||||
hide_tags.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags, sort_column],
|
||||
inputs=[hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
sort_column.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags, sort_column],
|
||||
inputs=[hide_tags, sort_column, search_extensions_text],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
|
@ -22,21 +22,37 @@ def register_page(page):
|
||||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||
|
||||
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
|
||||
def get_metadata(page: str = "", item: str = ""):
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
||||
if page is None:
|
||||
return JSONResponse({})
|
||||
|
||||
metadata = page.metadata.get(item)
|
||||
if metadata is None:
|
||||
return JSONResponse({})
|
||||
|
||||
return JSONResponse({"metadata": metadata})
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
@ -45,6 +61,7 @@ class ExtraNetworksPage:
|
||||
self.name = title.lower()
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
@ -66,6 +83,8 @@ class ExtraNetworksPage:
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
@ -86,12 +105,16 @@ class ExtraNetworksPage:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
@ -124,8 +147,16 @@ class ExtraNetworksPage:
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
|
||||
metadata_button = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||
"style": f"'{height}{width}{background_image}'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
@ -134,6 +165,7 @@ class ExtraNetworksPage:
|
||||
"card_clicked": onclick,
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
"metadata_button": metadata_button,
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
@ -208,6 +240,7 @@ def create_ui(container, button, tabname):
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
|
@ -4,7 +4,7 @@ basicsr
|
||||
fonts
|
||||
font-roboto
|
||||
gfpgan
|
||||
gradio==3.16.2
|
||||
gradio==3.23
|
||||
invisible-watermark
|
||||
numpy
|
||||
omegaconf
|
||||
@ -30,3 +30,4 @@ GitPython
|
||||
torchsde
|
||||
safetensors
|
||||
psutil
|
||||
rich
|
||||
|
@ -3,13 +3,13 @@ transformers==4.25.1
|
||||
accelerate==0.12.0
|
||||
basicsr==1.4.2
|
||||
gfpgan==1.3.8
|
||||
gradio==3.16.2
|
||||
gradio==3.23
|
||||
numpy==1.23.3
|
||||
Pillow==9.4.0
|
||||
realesrgan==0.3.0
|
||||
torch
|
||||
omegaconf==2.2.3
|
||||
pytorch_lightning==1.7.6
|
||||
pytorch_lightning==1.9.4
|
||||
scikit-image==0.19.2
|
||||
fonts
|
||||
font-roboto
|
||||
@ -25,6 +25,6 @@ lark==1.1.2
|
||||
inflection==0.5.1
|
||||
GitPython==3.1.30
|
||||
torchsde==0.2.5
|
||||
safetensors==0.2.7
|
||||
safetensors==0.3.0
|
||||
httpcore<=0.15
|
||||
fastapi==0.94.0
|
||||
|
@ -1,7 +1,9 @@
|
||||
function gradioApp() {
|
||||
const elems = document.getElementsByTagName('gradio-app')
|
||||
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
|
||||
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
||||
const elem = elems.length == 0 ? document : elems[0]
|
||||
|
||||
if (elem !== document) elem.getElementById = function(id){ return document.getElementById(id) }
|
||||
return elem.shadowRoot ? elem.shadowRoot : elem
|
||||
}
|
||||
|
||||
function get_uiCurrentTab() {
|
||||
|
@ -6,23 +6,21 @@ from tqdm import trange
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
||||
from modules.processing import Processed
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
from modules import processing, shared, sd_samplers, sd_samplers_common
|
||||
|
||||
import torch
|
||||
import k_diffusion as K
|
||||
|
||||
from PIL import Image
|
||||
from torch import autocast
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||
x = p.init_latent
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||
if shared.sd_model.parameterization == "v":
|
||||
dnw = K.external.CompVisVDenoiser(shared.sd_model)
|
||||
skip = 1
|
||||
else:
|
||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||
skip = 0
|
||||
sigmas = dnw.get_sigmas(steps).flip(0)
|
||||
|
||||
shared.state.sampling_steps = steps
|
||||
@ -37,7 +35,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
||||
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
||||
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||
t = dnw.sigma_to_t(sigma_in)
|
||||
|
||||
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
|
||||
@ -69,7 +67,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
x = p.init_latent
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||
if shared.sd_model.parameterization == "v":
|
||||
dnw = K.external.CompVisVDenoiser(shared.sd_model)
|
||||
skip = 1
|
||||
else:
|
||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||
skip = 0
|
||||
sigmas = dnw.get_sigmas(steps).flip(0)
|
||||
|
||||
shared.state.sampling_steps = steps
|
||||
@ -84,7 +87,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
||||
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
||||
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||
|
||||
if i == 1:
|
||||
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
|
||||
@ -125,7 +128,7 @@ class Script(scripts.Script):
|
||||
def show(self, is_img2img):
|
||||
return is_img2img
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def ui(self, is_img2img):
|
||||
info = gr.Markdown('''
|
||||
* `CFG Scale` should be 2 or lower.
|
||||
''')
|
||||
@ -213,4 +216,3 @@ class Script(scripts.Script):
|
||||
processed = processing.process_images(p)
|
||||
|
||||
return processed
|
||||
|
||||
|
@ -1,14 +1,10 @@
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
import math
|
||||
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import processing, shared, sd_samplers, images
|
||||
import modules.scripts as scripts
|
||||
from modules import deepbooru, images, processing, shared
|
||||
from modules.processing import Processed
|
||||
from modules.sd_samplers import samplers
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
from modules import deepbooru
|
||||
from modules.shared import opts, state
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
@ -20,39 +16,68 @@ class Script(scripts.Script):
|
||||
|
||||
def ui(self, is_img2img):
|
||||
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
|
||||
denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
|
||||
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
|
||||
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
|
||||
append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
|
||||
|
||||
return [loops, denoising_strength_change_factor, append_interrogation]
|
||||
return [loops, final_denoising_strength, denoising_curve, append_interrogation]
|
||||
|
||||
def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
|
||||
def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation):
|
||||
processing.fix_seed(p)
|
||||
batch_count = p.n_iter
|
||||
p.extra_generation_params = {
|
||||
"Denoising strength change factor": denoising_strength_change_factor,
|
||||
"Final denoising strength": final_denoising_strength,
|
||||
"Denoising curve": denoising_curve
|
||||
}
|
||||
|
||||
p.batch_size = 1
|
||||
p.n_iter = 1
|
||||
|
||||
output_images, info = None, None
|
||||
info = None
|
||||
initial_seed = None
|
||||
initial_info = None
|
||||
initial_denoising_strength = p.denoising_strength
|
||||
|
||||
grids = []
|
||||
all_images = []
|
||||
original_init_image = p.init_images
|
||||
original_prompt = p.prompt
|
||||
original_inpainting_fill = p.inpainting_fill
|
||||
state.job_count = loops * batch_count
|
||||
|
||||
initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
|
||||
|
||||
for n in range(batch_count):
|
||||
history = []
|
||||
def calculate_denoising_strength(loop):
|
||||
strength = initial_denoising_strength
|
||||
|
||||
if loops == 1:
|
||||
return strength
|
||||
|
||||
progress = loop / (loops - 1)
|
||||
match denoising_curve:
|
||||
case "Aggressive":
|
||||
strength = math.sin((progress) * math.pi * 0.5)
|
||||
|
||||
case "Lazy":
|
||||
strength = 1 - math.cos((progress) * math.pi * 0.5)
|
||||
|
||||
case _:
|
||||
strength = progress
|
||||
|
||||
change = (final_denoising_strength - initial_denoising_strength) * strength
|
||||
return initial_denoising_strength + change
|
||||
|
||||
history = []
|
||||
|
||||
for n in range(batch_count):
|
||||
# Reset to original init image at the start of each batch
|
||||
p.init_images = original_init_image
|
||||
|
||||
# Reset to original denoising strength
|
||||
p.denoising_strength = initial_denoising_strength
|
||||
|
||||
last_image = None
|
||||
|
||||
for i in range(loops):
|
||||
p.n_iter = 1
|
||||
p.batch_size = 1
|
||||
@ -72,26 +97,46 @@ class Script(scripts.Script):
|
||||
|
||||
processed = processing.process_images(p)
|
||||
|
||||
# Generation cancelled.
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
if initial_seed is None:
|
||||
initial_seed = processed.seed
|
||||
initial_info = processed.info
|
||||
|
||||
init_img = processed.images[0]
|
||||
|
||||
p.init_images = [init_img]
|
||||
p.seed = processed.seed + 1
|
||||
p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
|
||||
history.append(processed.images[0])
|
||||
p.denoising_strength = calculate_denoising_strength(i + 1)
|
||||
|
||||
if state.skipped:
|
||||
break
|
||||
|
||||
last_image = processed.images[0]
|
||||
p.init_images = [last_image]
|
||||
p.inpainting_fill = 1 # Set "masked content" to "original" for next loop.
|
||||
|
||||
if batch_count == 1:
|
||||
history.append(last_image)
|
||||
all_images.append(last_image)
|
||||
|
||||
if batch_count > 1 and not state.skipped and not state.interrupted:
|
||||
history.append(last_image)
|
||||
all_images.append(last_image)
|
||||
|
||||
p.inpainting_fill = original_inpainting_fill
|
||||
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
if len(history) > 1:
|
||||
grid = images.image_grid(history, rows=1)
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
|
||||
|
||||
grids.append(grid)
|
||||
all_images += history
|
||||
|
||||
if opts.return_grid:
|
||||
all_images = grids + all_images
|
||||
if opts.return_grid:
|
||||
grids.append(grid)
|
||||
|
||||
all_images = grids + all_images
|
||||
|
||||
processed = Processed(p, all_images, initial_seed, initial_info)
|
||||
|
||||
|
@ -17,22 +17,24 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
||||
def ui(self):
|
||||
selected_tab = gr.State(value=0)
|
||||
|
||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
|
||||
with gr.Column():
|
||||
with FormRow():
|
||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
|
||||
|
||||
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
|
||||
with FormRow():
|
||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
|
||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
|
||||
with FormRow():
|
||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
|
||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||
|
||||
with FormRow():
|
||||
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
with FormRow():
|
||||
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
|
||||
with FormRow():
|
||||
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
||||
with FormRow():
|
||||
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
||||
|
||||
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
||||
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
||||
|
@ -247,7 +247,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||
|
||||
state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
|
||||
|
||||
processed: Processed = cell(x, y, z)
|
||||
processed: Processed = cell(x, y, z, ix, iy, iz)
|
||||
|
||||
if processed_result is None:
|
||||
# Use our first processed result object as a template container to hold our full results
|
||||
@ -515,6 +515,7 @@ class Script(scripts.Script):
|
||||
zs = process_axis(z_opt, z_values)
|
||||
|
||||
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
||||
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
|
||||
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
|
||||
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
|
||||
|
||||
@ -558,8 +559,6 @@ class Script(scripts.Script):
|
||||
print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
|
||||
shared.total_tqdm.updateTotal(total_steps)
|
||||
|
||||
grid_infotext = [None]
|
||||
|
||||
state.xyz_plot_x = AxisInfo(x_opt, xs)
|
||||
state.xyz_plot_y = AxisInfo(y_opt, ys)
|
||||
state.xyz_plot_z = AxisInfo(z_opt, zs)
|
||||
@ -588,7 +587,9 @@ class Script(scripts.Script):
|
||||
else:
|
||||
second_axes_processed = 'y'
|
||||
|
||||
def cell(x, y, z):
|
||||
grid_infotext = [None] * (1 + len(zs))
|
||||
|
||||
def cell(x, y, z, ix, iy, iz):
|
||||
if shared.state.interrupted:
|
||||
return Processed(p, [], p.seed, "")
|
||||
|
||||
@ -600,7 +601,9 @@ class Script(scripts.Script):
|
||||
|
||||
res = process_images(pc)
|
||||
|
||||
if grid_infotext[0] is None:
|
||||
# Sets subgrid infotexts
|
||||
subgrid_index = 1 + iz
|
||||
if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
|
||||
pc.extra_generation_params = copy(pc.extra_generation_params)
|
||||
pc.extra_generation_params['Script'] = self.title()
|
||||
|
||||
@ -616,6 +619,12 @@ class Script(scripts.Script):
|
||||
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
|
||||
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
|
||||
|
||||
grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
|
||||
|
||||
# Sets main grid infotext
|
||||
if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
|
||||
pc.extra_generation_params = copy(pc.extra_generation_params)
|
||||
|
||||
if z_opt.label != 'Nothing':
|
||||
pc.extra_generation_params["Z Type"] = z_opt.label
|
||||
pc.extra_generation_params["Z Values"] = z_values
|
||||
@ -650,6 +659,9 @@ class Script(scripts.Script):
|
||||
|
||||
z_count = len(zs)
|
||||
|
||||
# Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
|
||||
processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]
|
||||
|
||||
if not include_lone_images:
|
||||
# Don't need sub-images anymore, drop from list:
|
||||
processed.images = processed.images[:z_count+1]
|
||||
|
8
webui.py
8
webui.py
@ -4,6 +4,7 @@ import time
|
||||
import importlib
|
||||
import signal
|
||||
import re
|
||||
import warnings
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
@ -17,6 +18,8 @@ from modules import paths, timer, import_hook, errors
|
||||
startup_timer = timer.Timer()
|
||||
|
||||
import torch
|
||||
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
startup_timer.record("import torch")
|
||||
|
||||
import gradio
|
||||
@ -240,7 +243,7 @@ def webui():
|
||||
shared.demo = modules.ui.create_ui()
|
||||
startup_timer.record("create ui")
|
||||
|
||||
if cmd_opts.gradio_queue:
|
||||
if not cmd_opts.no_gradio_queue:
|
||||
shared.demo.queue(64)
|
||||
|
||||
gradio_auth_creds = []
|
||||
@ -262,6 +265,9 @@ def webui():
|
||||
inbrowser=cmd_opts.autolaunch,
|
||||
prevent_thread_lock=True
|
||||
)
|
||||
for dep in shared.demo.dependencies:
|
||||
dep['show_progress'] = False # disable gradio css animation on component update
|
||||
|
||||
# after initial launch, disable --autolaunch for subsequent restarts
|
||||
cmd_opts.autolaunch = False
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user