From d686e73daa6cca399fe68976922cabde681f69f1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 26 Jun 2024 23:22:00 +0300
Subject: [PATCH] support for SD3: infinite prompt length, token counting

---
 modules/models/sd3/sd3_cond.py  | 225 ++++++++++++++++++++++++++++++++
 modules/models/sd3/sd3_model.py | 119 +----------------
 modules/prompt_parser.py        |   2 +-
 modules/sd_hijack.py            |   5 +-
 modules/sd_hijack_clip.py       |  59 ++++++---
 modules/sd_models.py            |   7 +-
 6 files changed, 278 insertions(+), 139 deletions(-)
 create mode 100644 modules/models/sd3/sd3_cond.py

diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py
new file mode 100644
index 000000000..c61ae0fe6
--- /dev/null
+++ b/modules/models/sd3/sd3_cond.py
@@ -0,0 +1,225 @@
+import os
+import safetensors
+import torch
+import typing
+
+from transformers import CLIPTokenizer, T5TokenizerFast
+
+from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
+from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
+
+
+class SafetensorsMapping(typing.Mapping):
+    def __init__(self, file):
+        self.file = file
+
+    def __len__(self):
+        return len(self.file.keys())
+
+    def __iter__(self):
+        for key in self.file.keys():
+            yield key
+
+    def __getitem__(self, key):
+        return self.file.get_tensor(key)
+
+
+CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
+CLIPL_CONFIG = {
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+}
+
+CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
+CLIPG_CONFIG = {
+    "hidden_act": "gelu",
+    "hidden_size": 1280,
+    "intermediate_size": 5120,
+    "num_attention_heads": 20,
+    "num_hidden_layers": 32,
+}
+
+T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
+T5_CONFIG = {
+    "d_ff": 10240,
+    "d_model": 4096,
+    "num_heads": 64,
+    "num_layers": 24,
+    "vocab_size": 32128,
+}
+
+
+class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
+    def __init__(self, clip_l, clip_g):
+        super().__init__()
+
+        self.clip_l = clip_l
+        self.clip_g = clip_g
+
+        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+        empty = self.tokenizer('')["input_ids"]
+        self.id_start = empty[0]
+        self.id_end = empty[1]
+        self.id_pad = empty[1]
+
+        self.return_pooled = True
+
+    def tokenize(self, texts):
+        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def encode_with_transformers(self, tokens):
+        tokens_g = tokens.clone()
+
+        for batch_pos in range(tokens_g.shape[0]):
+            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
+            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
+
+        l_out, l_pooled = self.clip_l(tokens)
+        g_out, g_pooled = self.clip_g(tokens_g)
+
+        lg_out = torch.cat([l_out, g_out], dim=-1)
+        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+
+        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
+
+        lg_out.pooled = vector_out
+        return lg_out
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.zeros((nvpt, 768+1280), device=devices.device)  # XXX
+
+
+class Sd3T5(torch.nn.Module):
+    def __init__(self, t5xxl):
+        super().__init__()
+
+        self.t5xxl = t5xxl
+        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
+
+        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
+        self.id_end = empty[0]
+        self.id_pad = empty[1]
+
+    def tokenize(self, texts):
+        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def tokenize_line(self, line, *, target_token_count=None):
+        if shared.opts.emphasis != "None":
+            parsed = prompt_parser.parse_prompt_attention(line)
+        else:
+            parsed = [[line, 1.0]]
+
+        tokenized = self.tokenize([text for text, _ in parsed])
+
+        tokens = []
+        multipliers = []
+
+        for text_tokens, (text, weight) in zip(tokenized, parsed):
+            if text == 'BREAK' and weight == -1:
+                continue
+
+            tokens += text_tokens
+            multipliers += [weight] * len(text_tokens)
+
+        tokens += [self.id_end]
+        multipliers += [1.0]
+
+        if target_token_count is not None:
+            if len(tokens) < target_token_count:
+                multipliers += [1.0] * (target_token_count - len(tokens))  # pad multipliers before tokens, while len(tokens) is still the unpadded length
+                tokens += [self.id_pad] * (target_token_count - len(tokens))
+            else:
+                tokens = tokens[0:target_token_count]
+                multipliers = multipliers[0:target_token_count]
+
+        return tokens, multipliers
+
+    def forward(self, texts, *, token_count):
+        if not self.t5xxl or not shared.opts.sd3_enable_t5:
+            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
+
+        tokens_batch = []
+
+        for text in texts:
+            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
+            tokens_batch.append(tokens)
+
+        t5_out, t5_pooled = self.t5xxl(tokens_batch)
+
+        return t5_out
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX
+
+
+class SD3Cond(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.tokenizer = SD3Tokenizer()
+
+        with torch.no_grad():
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+
+            if shared.opts.sd3_enable_t5:
+                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+            else:
+                self.t5xxl = None
+
+        self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
+        self.model_t5 = Sd3T5(self.t5xxl)
+
+        self.weights_loaded = False
+
+    def forward(self, prompts: list[str]):
+        lg_out, vector_out = self.model_lg(prompts)
+
+        token_count = lg_out.shape[1]
+
+        t5_out = self.model_t5(prompts, token_count=token_count)
+        lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+
+        return {
+            'crossattn': lgt_out,
+            'vector': vector_out,
+        }
+
+    def load_weights(self):
+        if self.weights_loaded:
+            return
+
+        clip_path = os.path.join(shared.models_path, "CLIP")
+
+        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+        with safetensors.safe_open(clip_g_file, framework="pt") as file:
+            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+
+        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+        with safetensors.safe_open(clip_l_file, framework="pt") as file:
+            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        if self.t5xxl:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+            with safetensors.safe_open(t5_file, framework="pt") as file:
+                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        self.weights_loaded = True
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.tensor([[0]], device=devices.device)  # XXX
+
+    def medvram_modules(self):
+        return [self.clip_g, self.clip_l, self.t5xxl]
+
+    def get_token_count(self, text):
+        _, token_count = self.model_lg.process_texts([text])
+
+        return token_count
+
+    def get_target_prompt_token_count(self, token_count):
+        return self.model_lg.get_target_prompt_token_count(token_count)
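The heart of the new module is how SD3Cond.forward assembles the three encoder outputs: CLIP-L and CLIP-G hidden states are concatenated along the feature axis, zero-padded on the right up to T5's 4096-channel width, and then stacked with the T5 hidden states along the token axis. A minimal shape sketch of that assembly (plain torch with dummy tensors, not webui code; one 77-token chunk assumed):

    import torch

    batch, tokens = 1, 77
    l_out = torch.zeros(batch, tokens, 768)    # CLIP-L hidden states
    g_out = torch.zeros(batch, tokens, 1280)   # CLIP-G hidden states
    t5_out = torch.zeros(batch, tokens, 4096)  # T5-XXL hidden states (stay zeros when T5 is disabled)

    lg_out = torch.cat([l_out, g_out], dim=-1)                              # (1, 77, 2048)
    lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))  # (1, 77, 4096)
    crossattn = torch.cat([lg_out, t5_out], dim=-2)                         # (1, 154, 4096)

    vector = torch.cat((torch.zeros(batch, 768), torch.zeros(batch, 1280)), dim=-1)  # pooled 'vector' cond, (1, 2048)

Because Sd3T5.tokenize_line pads or truncates to token_count = lg_out.shape[1], the dim=-2 concatenation always lines up: a prompt longer than 75 tokens simply yields more CLIP chunks, and the T5 branch is stretched to match.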
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
index 309a7f863..10209f82a 100644
--- a/modules/models/sd3/sd3_model.py
+++ b/modules/models/sd3/sd3_model.py
@@ -1,127 +1,12 @@
 import contextlib
-import os
-from typing import Mapping
 
-import safetensors
 import torch
 
 import k_diffusion
-from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
 from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
+from modules.models.sd3.sd3_cond import SD3Cond
 
-from modules import shared, modelloader, devices
-
-CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
-CLIPG_CONFIG = {
-    "hidden_act": "gelu",
-    "hidden_size": 1280,
-    "intermediate_size": 5120,
-    "num_attention_heads": 20,
-    "num_hidden_layers": 32,
-}
-
-CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
-CLIPL_CONFIG = {
-    "hidden_act": "quick_gelu",
-    "hidden_size": 768,
-    "intermediate_size": 3072,
-    "num_attention_heads": 12,
-    "num_hidden_layers": 12,
-}
-
-T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
-T5_CONFIG = {
-    "d_ff": 10240,
-    "d_model": 4096,
-    "num_heads": 64,
-    "num_layers": 24,
-    "vocab_size": 32128,
-}
-
-
-class SafetensorsMapping(Mapping):
-    def __init__(self, file):
-        self.file = file
-
-    def __len__(self):
-        return len(self.file.keys())
-
-    def __iter__(self):
-        for key in self.file.keys():
-            yield key
-
-    def __getitem__(self, key):
-        return self.file.get_tensor(key)
-
-
-class SD3Cond(torch.nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.tokenizer = SD3Tokenizer()
-
-        with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
-
-            if shared.opts.sd3_enable_t5:
-                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
-            else:
-                self.t5xxl = None
-
-        self.weights_loaded = False
-
-    def forward(self, prompts: list[str]):
-        res = []
-
-        for prompt in prompts:
-            tokens = self.tokenizer.tokenize_with_weights(prompt)
-            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
-            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
-
-            if self.t5xxl and shared.opts.sd3_enable_t5:
-                t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
-            else:
-                t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
-
-            lg_out = torch.cat([l_out, g_out], dim=-1)
-            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
-            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
-            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
-
-            res.append({
-                'crossattn': lgt_out[0].to(devices.device),
-                'vector': vector_out[0].to(devices.device),
-            })
-
-        return res
-
-    def load_weights(self):
-        if self.weights_loaded:
-            return
-
-        clip_path = os.path.join(shared.models_path, "CLIP")
-
-        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
-        with safetensors.safe_open(clip_g_file, framework="pt") as file:
-            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
-
-        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
-        with safetensors.safe_open(clip_l_file, framework="pt") as file:
-            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
-
-        if self.t5xxl:
-            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
-            with safetensors.safe_open(t5_file, framework="pt") as file:
-                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
-
-        self.weights_loaded = True
-
-    def encode_embedding_init_text(self, init_text, nvpt):
-        return torch.tensor([[0]], device=devices.device)  # XXX
-
-    def medvram_modules(self):
-        return [self.clip_g, self.clip_l, self.t5xxl]
+from modules import shared, devices
 
 
 class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
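Besides moving code, this removal changes the conditioner's contract: the old forward built conditioning one prompt at a time and returned a list of per-prompt dicts, while the relocated SD3Cond.forward runs the whole batch through the chunking machinery and returns a single dict of batched tensors. An illustrative sketch of the two shapes (dummy tensors only, not webui code):

    import torch

    # old contract: one dict per prompt, tensors without a batch axis
    old_style = [{'crossattn': torch.zeros(154, 4096), 'vector': torch.zeros(2048)} for _ in range(2)]

    # new contract: one dict covering the whole batch of 2 prompts
    new_style = {'crossattn': torch.zeros(2, 154, 4096), 'vector': torch.zeros(2, 2048)}

The prompt_parser and sd_models changes below exist to accommodate this batched-dict form.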
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index cba134554..4e393d286 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
 
 
 class DictWithShape(dict):
-    def __init__(self, x, shape):
+    def __init__(self, x, shape=None):
         super().__init__()
         self.update(x)
 
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index e139d9964..d5b2989f4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -325,7 +325,10 @@ class StableDiffusionModelHijack:
         if self.clip is None:
             return "-", "-"
 
-        _, token_count = self.clip.process_texts([text])
+        if hasattr(self.clip, 'get_token_count'):
+            token_count = self.clip.get_token_count(text)
+        else:
+            _, token_count = self.clip.process_texts([text])
 
         return token_count, self.clip.get_target_prompt_token_count(token_count)
 
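Two small compatibility shims. DictWithShape can afford a shape=None default because the class derives its shape from the tensor it wraps rather than from the constructor argument, so SD3's batched dict passes through without a precomputed shape. A sketch of the behavior this relies on (assuming, as in the surrounding file, that the shape property reads self['crossattn'].shape):

    import torch
    from modules.prompt_parser import DictWithShape

    cond = DictWithShape({'crossattn': torch.zeros(2, 154, 4096), 'vector': torch.zeros(2, 2048)})
    print(cond.shape)  # torch.Size([2, 154, 4096]) -- taken from the tensor, so the shape argument may be None

The hasattr(self.clip, 'get_token_count') check in sd_hijack keeps older wrappers working: SD1/SD2/SDXL conditioners still go through process_texts, while SD3Cond supplies the new method and routes counting to its CLIP branch.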
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 355df3d30..a479148fc 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
 are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
 
 
-class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
-    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
-    have unlimited prompt length and assign weights to tokens in prompt.
-    """
-
-    def __init__(self, wrapped, hijack):
+class TextConditionalModel(torch.nn.Module):
+    def __init__(self):
         super().__init__()
 
-        self.wrapped = wrapped
-        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
-        depending on model."""
-
-        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+        self.hijack = sd_hijack.model_hijack
         self.chunk_length = 75
 
-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.legacy_ucg_val = None
+        self.is_trainable = False
+        self.input_key = 'txt'
+        self.return_pooled = False
+
+        self.comma_token = None
+        self.id_start = None
+        self.id_end = None
+        self.id_pad = None
 
     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""
@@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
 
-        if opts.use_old_emphasis_implementation:
-            import modules.sd_hijack_clip_old
-            return modules.sd_hijack_clip_old.forward_old(self, texts)
-
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}
@@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
             self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
 
-        if getattr(self.wrapped, 'return_pooled', False):
+        if self.return_pooled:
             return torch.hstack(zs), zs[0].pooled
         else:
             return torch.hstack(zs)
@@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         return z
 
 
+class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
+    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+    have unlimited prompt length and assign weights to tokens in prompt.
+    """
+
+    def __init__(self, wrapped, hijack):
+        super().__init__()
+
+        self.hijack = hijack
+
+        self.wrapped = wrapped
+        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+        depending on model."""
+
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
+
+        self.legacy_ucg_val = None  # for sgm codebase
+
+    def forward(self, texts):
+        if opts.use_old_emphasis_implementation:
+            import modules.sd_hijack_clip_old
+            return modules.sd_hijack_clip_old.forward_old(self, texts)
+
+        return super().forward(texts)
+
+
 class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
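After this split, adding a text encoder no longer means wrapping a Frozen*Embedder: a conditioner subclasses TextConditionalModel, sets the special token ids, and implements tokenize and encode_with_transformers to inherit the chunking, emphasis, embedding, and unlimited-length machinery; that is exactly how Sd3ClipLG plugs in above. A skeletal subclass for orientation (hypothetical names; a sketch, not working webui code):

    from modules import sd_hijack_clip

    class MyTextModel(sd_hijack_clip.TextConditionalModel):
        def __init__(self, tokenizer, encoder):
            super().__init__()
            self.tokenizer = tokenizer   # any HF-style tokenizer
            self.encoder = encoder       # the underlying transformer

            self.id_start = 49406        # BOS/EOS/pad ids of the tokenizer in use (CLIP values shown)
            self.id_end = 49407
            self.id_pad = 49407
            self.return_pooled = False   # True if encode_with_transformers attaches z.pooled

        def tokenize(self, texts):
            # raw ids, no truncation: the base class does the 75-token chunking
            return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        def encode_with_transformers(self, tokens):
            # tokens: (batch, chunk_length + 2) ids including start/end; returns per-token embeddings
            return self.encoder(tokens)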
- """ - - def __init__(self, wrapped, hijack): +class TextConditionalModel(torch.nn.Module): + def __init__(self): super().__init__() - self.wrapped = wrapped - """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, - depending on model.""" - - self.hijack: sd_hijack.StableDiffusionModelHijack = hijack + self.hijack = sd_hijack.model_hijack self.chunk_length = 75 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None + self.is_trainable = False + self.input_key = 'txt' + self.return_pooled = False + + self.comma_token = None + self.id_start = None + self.id_end = None + self.id_pad = None def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ - if opts.use_old_emphasis_implementation: - import modules.sd_hijack_clip_old - return modules.sd_hijack_clip_old.forward_old(self, texts) - batch_chunks, token_count = self.process_texts(texts) used_embeddings = {} @@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": self.hijack.extra_generation_params["Emphasis"] = opts.emphasis - if getattr(self.wrapped, 'return_pooled', False): + if self.return_pooled: return torch.hstack(zs), zs[0].pooled else: return torch.hstack(zs) @@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return z +class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. + """ + + def __init__(self, wrapped, hijack): + super().__init__() + + self.hijack = hijack + + self.wrapped = wrapped + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.return_pooled = getattr(self.wrapped, 'return_pooled', False) + + self.legacy_ucg_val = None # for sgm codebase + + def forward(self, texts): + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) + + return super().forward(texts) + + class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) diff --git a/modules/sd_models.py b/modules/sd_models.py index 61fb881ba..45575e440 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -722,7 +722,12 @@ def get_empty_cond(sd_model): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] else: - return sd_model.cond_stage_model([""]) + d = sd_model.cond_stage_model([""]) + + if isinstance(d, dict): + d = d['crossattn'] + + return d def send_model_to_cpu(m):