From 3b2f51602dbd1a8a94706ae71943403f07539b1c Mon Sep 17 00:00:00 2001
From: AnyISalIn <anyisalin@gmail.com>
Date: Fri, 11 Aug 2023 20:21:38 +0800
Subject: [PATCH] xyz_grid: support refiner_checkpoint and refiner_switch_at

Signed-off-by: AnyISalIn <anyisalin@gmail.com>
---
 scripts/xyz_grid.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index d37b428fc..a45e6d611 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -85,6 +85,23 @@ def confirm_checkpoints(p, xs):
         if modules.sd_models.get_closet_checkpoint_match(x) is None:
             raise RuntimeError(f"Unknown checkpoint: {x}")
 
+def apply_refiner_checkpoint(p, x, xs):
+    if x == 'None':
+        p.override_settings['sd_refiner_checkpoint'] = 'None'
+        return
+
+    info = modules.sd_models.get_closet_checkpoint_match(x)
+    if info is None:
+        raise RuntimeError(f"Unknown checkpoint: {x}")
+    p.override_settings['sd_refiner_checkpoint'] = info.name
+
+def confirm_refiner_checkpoints(p, xs):
+    for x in xs:
+        if x == 'None':
+            continue
+        if modules.sd_models.get_closet_checkpoint_match(x) is None:
+            raise RuntimeError(f"Unknown checkpoint: {x}")
+
 
 def apply_clip_skip(p, x, xs):
     opts.data["CLIP_stop_at_last_layers"] = x
@@ -241,6 +258,8 @@ axis_options = [
     AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
     AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
     AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
+    AxisOption("Refiner checkpoint", str, apply_refiner_checkpoint, format_value=format_remove_path, confirm=confirm_refiner_checkpoints, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
+    AxisOption("Refiner switch at", float, apply_override('sd_refiner_switch_at'))
 ]