diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcf772d2f..d4baf0660 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -22,6 +22,7 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.extns = [".jpg",".jpeg",".png"] self.dataset = [] @@ -32,7 +33,7 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): image = Image.open(path) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d7efdef29..b6c78cf85 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,12 +12,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + extns = [".jpg",".jpeg",".png"] assert src != dst, 'same directory specified as source and destination' os.makedirs(dst, exist_ok=True) - files = os.listdir(src) + files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns] shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a06..45397be92 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -161,6 +161,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps + extns = [".jpg",".jpeg",".png"] filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') @@ -200,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) + tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns]) epoch_len = (tr_img_len * num_repeats) + tr_img_len pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)