diff --git a/demo/calculator/run.py b/demo/calculator/run.py index 96aafc1638..265fff01d5 100644 --- a/demo/calculator/run.py +++ b/demo/calculator/run.py @@ -23,7 +23,6 @@ demo = gr.Interface( ], title="test calculator", description="heres a sample toy calculator. enjoy!", - flagging_options=["this", "or", "that"], ) if __name__ == "__main__": diff --git a/demo/image_mod/run.py b/demo/image_mod/run.py index f141cfcf1e..707cae3afe 100644 --- a/demo/image_mod/run.py +++ b/demo/image_mod/run.py @@ -5,7 +5,8 @@ def image_mod(image): return image.rotate(45) -demo = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image") +demo = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image", + flagging_options=["blurry", "incorrect", "other"]) if __name__ == "__main__": demo.launch() diff --git a/gradio/interface.py b/gradio/interface.py index 048bad5300..8b90632a23 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -438,6 +438,19 @@ class Interface(Blocks): ) if self.description: Markdown(self.description) + + def render_flag_btns(flagging_options): + if flagging_options is None: + return [(Button("Flag"), None)] + else: + return [ + ( + Button("Flag as " + flag_option), + flag_option, + ) + for flag_option in flagging_options + ] + with Row().style(equal_height=False): if self.interface_type in [ self.InterfaceTypes.STANDARD, @@ -472,7 +485,7 @@ class Interface(Blocks): clear_btn = Button("Clear") submit_btn = Button("Submit", variant="primary") if self.allow_flagging == "manual": - flag_btn = Button("Flag") + flag_btns = render_flag_btns(self.flagging_options) if self.interface_type in [ self.InterfaceTypes.STANDARD, @@ -488,7 +501,7 @@ class Interface(Blocks): clear_btn = Button("Clear") submit_btn = Button("Generate", variant="primary") if self.allow_flagging == "manual": - flag_btn = Button("Flag") + flag_btns = render_flag_btns(self.flagging_options) if self.interpretation: interpretation_btn = Button("Interpret") submit_fn = ( @@ -556,26 +569,34 @@ class Interface(Blocks): )} """, ) + + class FlagMethod: + def __init__(self, flagging_callback, flag_option=None): + self.flagging_callback = flagging_callback + self.flag_option = flag_option + + def __call__(self, *flag_data): + self.flagging_callback.flag(flag_data, flag_option=self.flag_option) + if self.allow_flagging == "manual": if self.interface_type in [ self.InterfaceTypes.STANDARD, self.InterfaceTypes.OUTPUT_ONLY, + self.InterfaceTypes.UNIFIED, ]: - flag_btn.click( - lambda *flag_data: self.flagging_callback.flag(flag_data), - inputs=self.input_components + self.output_components, - outputs=[], - _preprocess=False, - queue=False, - ) - elif self.interface_type == self.InterfaceTypes.UNIFIED: - flag_btn.click( - lambda *flag_data: self.flagging_callback.flag(flag_data), - inputs=self.input_components, - outputs=[], - _preprocess=False, - queue=False, - ) + if self.interface_type == self.InterfaceTypes.UNIFIED: + flag_components = self.input_components + else: + flag_components = self.input_components + self.output_components + for flag_btn, flag_option in flag_btns: + flag_method = FlagMethod(self.flagging_callback, flag_option) + flag_btn.click( + flag_method, + inputs=flag_components, + outputs=[], + _preprocess=False, + queue=False, + ) if self.examples: non_state_inputs = [