diff --git a/gradio/interface.py b/gradio/interface.py index d5e8f49eae..186c932c51 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -40,7 +40,7 @@ class Interface: return list(Interface.instances) @classmethod - def load(cls, name, src=None, api_key=None, alias=None, **kwargs): + def load(cls, name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs): """ Class method to construct an Interface from an external source repository, such as huggingface. Parameters: @@ -52,15 +52,15 @@ class Interface: (gradio.Interface): a Gradio Interface object for the given model """ interface_info = load_interface(name, src, api_key, alias) - # create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable - # and that can cause some issues since the internal prediction function may rely on the original interface_info dict - kwargs = dict(interface_info, **kwargs) + # create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable + # and that can cause some issues since the internal prediction function may rely on the original interface_info dict + kwargs = dict(interface_info, **kwargs) interface = cls(**kwargs) interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess return interface @classmethod - def from_pipeline(cls, pipeline, **kwargs): + def from_pipeline(cls, pipeline: "transformers.Pipeline", **kwargs): """ Class method to construct an Interface from a Hugging Face transformers.Pipeline. pipeline (transformers.Pipeline): @@ -68,7 +68,7 @@ class Interface: (gradio.Interface): a Gradio Interface object from the given Pipeline """ interface_info = load_from_pipeline(pipeline) - kwargs = dict(interface_info, **kwargs) + kwargs = dict(interface_info, **kwargs) interface = cls(**kwargs) return interface @@ -76,7 +76,7 @@ class Interface: examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True, capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True, title=None, description=None, article=None, thumbnail=None, - css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None, + css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None, encrypt=False, show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None): """ Parameters: @@ -129,7 +129,7 @@ class Interface: self.predict_durations = [[0, 0]] * len(fn) self.function_names = [func.__name__ for func in fn] self.__name__ = ", ".join(self.function_names) - + if verbose is not None: warnings.warn("The `verbose` parameter in the `Interface` is deprecated and has no effect.") @@ -140,7 +140,7 @@ class Interface: self.show_output = show_output self.flag_hash = random.getrandbits(32) self.capture_session = capture_session - + if capture_session is not None: warnings.warn("The `capture_session` parameter in the `Interface` will be deprecated in the near future.") @@ -174,12 +174,12 @@ class Interface: "Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.") self.num_shap = num_shap self.examples_per_page = examples_per_page - + self.simple_server = None self.allow_screenshot = allow_screenshot # For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True - self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True" - self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True" + self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True") == "True" + self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True" self.flagging_options = flagging_options self.flagging_dir = flagging_dir self.encrypt = encrypt @@ -188,13 +188,13 @@ class Interface: self.share_url = None self.local_url = None self.ip_address = networking.get_local_ip_address() - + if show_tips is not None: warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead") self.requires_permissions = any( [component.requires_permissions for component in self.input_components]) - + self.enable_queue = enable_queue if self.enable_queue is not None: warnings.warn("The `enable_queue` parameter in the `Interface` will be deprecated. Please use the `enable_queue` parameter in `launch()` instead") @@ -207,7 +207,7 @@ class Interface: try: import tensorflow as tf self.session = tf.get_default_graph(), \ - tf.keras.backend.get_session() + tf.keras.backend.get_session() except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, # just ignore this. @@ -240,7 +240,7 @@ class Interface: if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API output = self.run_prediction(params, called_directly=True) else: - output, _ = self.process(params) + output, _ = self.process(params) return output[0] if len(output) == 1 else output def __str__(self): @@ -249,7 +249,7 @@ class Interface: def __repr__(self): repr = "Gradio Interface for: {}".format( ", ".join(fn.__name__ for fn in self.predict)) - repr += "\n" + "-"*len(repr) + repr += "\n" + "-" * len(repr) repr += "\ninputs:" for component in self.input_components: repr += "\n|-{}".format(str(component)) @@ -296,7 +296,7 @@ class Interface: function_index = i // outputs_per_function component_index = i - function_index * outputs_per_function ret_name = "Output " + \ - str(component_index + 1) if outputs_per_function > 1 else "Output" + str(component_index + 1) if outputs_per_function > 1 else "Output" if iface["label"] is None: iface["label"] = ret_name if len(self.predict) > 1: @@ -344,7 +344,7 @@ class Interface: """ if self.api_mode: # Serialize the input processed_input = [input_component.serialize(processed_input[i], called_directly) - for i, input_component in enumerate(self.input_components)] + for i, input_component in enumerate(self.input_components)] predictions = [] durations = [] output_component_counter = 0 @@ -395,7 +395,7 @@ class Interface: for i, input_component in enumerate(self.input_components)] predictions, durations = self.run_prediction( processed_input, return_duration=True) - processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None + processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None for i, output_component in enumerate(self.output_components)] return processed_output, durations @@ -467,7 +467,7 @@ class Interface: raise ValueError( "The package `shap` is required for this interpretation method. Try: `pip install shap`") input_component = self.input_components[i] - if not(input_component.interpret_by_tokens): + if not (input_component.interpret_by_tokens): raise ValueError( "Input component {} does not support `shap` interpretation".format(input_component)) @@ -494,7 +494,7 @@ class Interface: explainer = shap.KernelExplainer( get_masked_prediction, np.zeros((1, num_total_segments))) shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int( - self.num_shap*num_total_segments), silent=True) + self.num_shap * num_total_segments), silent=True) scores.append(input_component.get_interpretation_scores( raw_input[i], None, shap_values[0], masks=masks, tokens=tokens)) alternative_outputs.append([]) @@ -613,7 +613,7 @@ class Interface: # If running in a colab or not able to access localhost, automatically create a shareable link is_colab = utils.colab_check() - if is_colab or not(networking.url_ok(path_to_local_server)): + if is_colab or not (networking.url_ok(path_to_local_server)): share = True if is_colab: if debug: @@ -668,7 +668,7 @@ class Interface: display(IFrame(share_url, width=self.width, height=self.height)) else: display(IFrame(path_to_local_server, - width=self.width, height=self.height)) + width=self.width, height=self.height)) except ImportError: pass # IPython is not available so does not print inline. @@ -709,9 +709,9 @@ class Interface: self.server.shutdown() self.server_thread.join() print("Closing server running on port: {}".format(self.server_port)) - except AttributeError: # can't close if not running + except AttributeError: # can't close if not running pass - except OSError: # sometimes OSError is thrown when shutting down + except OSError: # sometimes OSError is thrown when shutting down pass def integrate(self, comet_ml=None, wandb=None, mlflow=None): @@ -736,7 +736,7 @@ class Interface: analytics_integration = "WandB" if self.share_url is not None: wandb.log({"Gradio panel": wandb.Html('')}) + str(self.width) + '" height="' + str(self.height) + '" frameBorder="0">')}) else: print( "The WandB integration requires you to `launch(share=True)` first.") @@ -749,8 +749,8 @@ class Interface: mlflow.log_param("Gradio Interface Local Link", self.local_url) if self.analytics_enabled and analytics_integration: - data = {'integration': analytics_integration} - utils.integration_analytics(data) + data = {'integration': analytics_integration} + utils.integration_analytics(data) def close_all(verbose=True): @@ -760,6 +760,6 @@ def close_all(verbose=True): def reset_all(): - warnings.warn("The `reset_all()` method has been renamed to `close_all()` " - "and will be deprecated. Please use `close_all()` instead.") + warnings.warn("The `reset_all()` method has been renamed to `close_all()` " + "and will be deprecated. Please use `close_all()` instead.") close_all()