import torch
from diffusers import DiffusionPipeline # type: ignore
import gradio as gr
# Load the latent-diffusion text-to-image pipeline (weights are downloaded
# and cached on first use).
generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# Prefer the GPU when one is available; otherwise the pipeline stays on CPU.
device = "cuda" if torch.cuda.is_available() else None
if device is not None:
    generator = generator.to(device)
def generate(prompts):
    """Render one image per prompt with the shared diffusion pipeline.

    Called by Gradio in batch mode, so ``prompts`` is a sequence of
    strings; returns a single-element list (one list per output
    component) containing the generated images.
    """
    batch = list(prompts)
    rendered = generator(batch).images  # type: ignore
    return [rendered]
# Batched text-to-image UI: a textbox in, an image out. Gradio groups
# concurrent requests into batches of up to `max_batch_size` prompts.
demo = gr.Interface(
    generate,
    "textbox",
    "image",
    batch=True,
    max_batch_size=4,  # Set the batch size based on your CPU/GPU memory
)
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()