This commit is contained in:
aliabid94 2022-06-09 15:46:07 -07:00 committed by GitHub
parent 8544295df2
commit 293a5916cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 19 deletions

View File

@ -23,7 +23,6 @@ demo = gr.Interface(
],
title="test calculator",
description="heres a sample toy calculator. enjoy!",
flagging_options=["this", "or", "that"],
)
if __name__ == "__main__":

View File

@ -5,7 +5,8 @@ def image_mod(image):
return image.rotate(45)
demo = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image")
demo = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image",
flagging_options=["blurry", "incorrect", "other"])
if __name__ == "__main__":
demo.launch()

View File

@ -438,6 +438,19 @@ class Interface(Blocks):
)
if self.description:
Markdown(self.description)
def render_flag_btns(flagging_options):
if flagging_options is None:
return [(Button("Flag"), None)]
else:
return [
(
Button("Flag as " + flag_option),
flag_option,
)
for flag_option in flagging_options
]
with Row().style(equal_height=False):
if self.interface_type in [
self.InterfaceTypes.STANDARD,
@ -472,7 +485,7 @@ class Interface(Blocks):
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btns = render_flag_btns(self.flagging_options)
if self.interface_type in [
self.InterfaceTypes.STANDARD,
@ -488,7 +501,7 @@ class Interface(Blocks):
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btns = render_flag_btns(self.flagging_options)
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
@ -556,26 +569,34 @@ class Interface(Blocks):
)}
""",
)
class FlagMethod:
def __init__(self, flagging_callback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
def __call__(self, *flag_data):
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
if self.allow_flagging == "manual":
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]:
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components + self.output_components,
outputs=[],
_preprocess=False,
queue=False,
)
elif self.interface_type == self.InterfaceTypes.UNIFIED:
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components,
outputs=[],
_preprocess=False,
queue=False,
)
if self.interface_type == self.InterfaceTypes.UNIFIED:
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
for flag_btn, flag_option in flag_btns:
flag_method = FlagMethod(self.flagging_callback, flag_option)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=[],
_preprocess=False,
queue=False,
)
if self.examples:
non_state_inputs = [