Introduce example type hints

Ömer Faruk Özdemir 2021-12-24 16:12:00 +03:00
parent d9f302ab82
commit 831e758fdb

gradio/interface.py

@@ -40,7 +40,7 @@ class Interface:
return list(Interface.instances)
@classmethod
- def load(cls, name, src=None, api_key=None, alias=None, **kwargs):
+ def load(cls, name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs):
"""
Class method to construct an Interface from an external source repository, such as huggingface.
Parameters:
@@ -52,15 +52,15 @@ class Interface:
(gradio.Interface): a Gradio Interface object for the given model
"""
interface_info = load_interface(name, src, api_key, alias)
- # create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
- # and that can cause some issues since the internal prediction function may rely on the original interface_info dict
- kwargs = dict(interface_info, **kwargs)
+ # create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
+ # and that can cause some issues since the internal prediction function may rely on the original interface_info dict
+ kwargs = dict(interface_info, **kwargs)
interface = cls(**kwargs)
interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess
return interface
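# Usage sketch for the classmethod typed above, assuming the "huggingface/<model-id>"
# source prefix that load_interface resolves; a minimal illustration, not gradio source:
import gradio as gr

demo = gr.Interface.load("huggingface/gpt2", alias="gpt2-demo")
demo.launch()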
@classmethod
- def from_pipeline(cls, pipeline, **kwargs):
+ def from_pipeline(cls, pipeline: "transformers.Pipeline", **kwargs):
"""
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline):
@@ -68,7 +68,7 @@ class Interface:
(gradio.Interface): a Gradio Interface object from the given Pipeline
"""
interface_info = load_from_pipeline(pipeline)
- kwargs = dict(interface_info, **kwargs)
+ kwargs = dict(interface_info, **kwargs)
interface = cls(**kwargs)
return interface
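# Sketch of from_pipeline, assuming the transformers package is installed:
import gradio as gr
from transformers import pipeline

pipe = pipeline("sentiment-analysis")
gr.Interface.from_pipeline(pipe).launch()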
@@ -76,7 +76,7 @@ class Interface:
examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True,
capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
- css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None,
+ css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None,
encrypt=False, show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None):
"""
Parameters:
@@ -129,7 +129,7 @@ class Interface:
self.predict_durations = [[0, 0]] * len(fn)
self.function_names = [func.__name__ for func in fn]
self.__name__ = ", ".join(self.function_names)
if verbose is not None:
warnings.warn("The `verbose` parameter in the `Interface` is deprecated and has no effect.")
@@ -140,7 +140,7 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
if capture_session is not None:
warnings.warn("The `capture_session` parameter in the `Interface` will be deprecated in the near future.")
@@ -174,12 +174,12 @@ class Interface:
"Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
self.num_shap = num_shap
self.examples_per_page = examples_per_page
self.simple_server = None
self.allow_screenshot = allow_screenshot
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
- self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True"
- self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
+ self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True") == "True"
+ self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
self.flagging_options = flagging_options
self.flagging_dir = flagging_dir
self.encrypt = encrypt
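# The flagging/analytics defaults above follow a three-step fallback: explicit argument,
# then GRADIO_* environment variable, then True. A standalone sketch (resolve_flag is a
# hypothetical helper, not gradio code):
import os

def resolve_flag(value, env_name):
    if value is not None:
        return value
    return os.getenv(env_name, "True") == "True"

assert resolve_flag(False, "GRADIO_ALLOW_FLAGGING") is False      # explicit argument wins
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
assert resolve_flag(None, "GRADIO_ANALYTICS_ENABLED") is False    # falls back to the env var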
@@ -188,13 +188,13 @@ class Interface:
self.share_url = None
self.local_url = None
self.ip_address = networking.get_local_ip_address()
if show_tips is not None:
warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead")
self.requires_permissions = any(
[component.requires_permissions for component in self.input_components])
self.enable_queue = enable_queue
if self.enable_queue is not None:
warnings.warn("The `enable_queue` parameter in the `Interface` will be deprecated. Please use the `enable_queue` parameter in `launch()` instead")
@@ -207,7 +207,7 @@ class Interface:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
- tf.keras.backend.get_session()
+ tf.keras.backend.get_session()
except (ImportError, AttributeError):
# If they are using TF >= 2.0 or don't have TF,
# just ignore this.
@@ -240,7 +240,7 @@ class Interface:
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
output = self.run_prediction(params, called_directly=True)
else:
- output, _ = self.process(params)
+ output, _ = self.process(params)
return output[0] if len(output) == 1 else output
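# The direct-call path above (presumably Interface.__call__) lets an Interface be used
# like a plain function; a sketch with a single text-to-text interface:
import gradio as gr

echo = gr.Interface(fn=lambda s: s.upper(), inputs="text", outputs="text")
print(echo("hello"))  # runs preprocess -> fn -> postprocess and returns "HELLO"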
def __str__(self):
@@ -249,7 +249,7 @@ class Interface:
def __repr__(self):
repr = "Gradio Interface for: {}".format(
", ".join(fn.__name__ for fn in self.predict))
- repr += "\n" + "-"*len(repr)
+ repr += "\n" + "-" * len(repr)
repr += "\ninputs:"
for component in self.input_components:
repr += "\n|-{}".format(str(component))
@@ -296,7 +296,7 @@ class Interface:
function_index = i // outputs_per_function
component_index = i - function_index * outputs_per_function
ret_name = "Output " + \
- str(component_index + 1) if outputs_per_function > 1 else "Output"
+ str(component_index + 1) if outputs_per_function > 1 else "Output"
if iface["label"] is None:
iface["label"] = ret_name
if len(self.predict) > 1:
@@ -344,7 +344,7 @@ class Interface:
"""
if self.api_mode: # Serialize the input
processed_input = [input_component.serialize(processed_input[i], called_directly)
- for i, input_component in enumerate(self.input_components)]
+ for i, input_component in enumerate(self.input_components)]
predictions = []
durations = []
output_component_counter = 0
@@ -395,7 +395,7 @@ class Interface:
for i, input_component in enumerate(self.input_components)]
predictions, durations = self.run_prediction(
processed_input, return_duration=True)
- processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
+ processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
return processed_output, durations
@@ -467,7 +467,7 @@ class Interface:
raise ValueError(
"The package `shap` is required for this interpretation method. Try: `pip install shap`")
input_component = self.input_components[i]
- if not(input_component.interpret_by_tokens):
+ if not (input_component.interpret_by_tokens):
raise ValueError(
"Input component {} does not support `shap` interpretation".format(input_component))
@@ -494,7 +494,7 @@ class Interface:
explainer = shap.KernelExplainer(
get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(
- self.num_shap*num_total_segments), silent=True)
+ self.num_shap * num_total_segments), silent=True)
scores.append(input_component.get_interpretation_scores(
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
alternative_outputs.append([])
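# Hedged sketch of the KernelExplainer pattern above with a toy masking function;
# the names are illustrative, not gradio's internals:
import numpy as np
import shap

num_total_segments = 4

def get_masked_prediction(binary_masks):
    # toy model: the score is just the fraction of unmasked segments
    return binary_masks.mean(axis=1, keepdims=True)

explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(
    np.ones((1, num_total_segments)), nsamples=2 * num_total_segments, silent=True)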
@@ -613,7 +613,7 @@ class Interface:
# If running in a colab or not able to access localhost, automatically create a shareable link
is_colab = utils.colab_check()
- if is_colab or not(networking.url_ok(path_to_local_server)):
+ if is_colab or not (networking.url_ok(path_to_local_server)):
share = True
if is_colab:
if debug:
@@ -668,7 +668,7 @@ class Interface:
display(IFrame(share_url, width=self.width, height=self.height))
else:
display(IFrame(path_to_local_server,
- width=self.width, height=self.height))
+ width=self.width, height=self.height))
except ImportError:
pass # IPython is not available so does not print inline.
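# Sketch of the inline notebook embed above, assuming IPython is available; the 900x500
# size matches the Interface width/height defaults shown earlier in this diff:
from IPython.display import IFrame, display

display(IFrame("http://127.0.0.1:7860/", width=900, height=500))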
@@ -709,9 +709,9 @@ class Interface:
self.server.shutdown()
self.server_thread.join()
print("Closing server running on port: {}".format(self.server_port))
- except AttributeError: # can't close if not running
+ except AttributeError: # can't close if not running
pass
- except OSError: # sometimes OSError is thrown when shutting down
+ except OSError: # sometimes OSError is thrown when shutting down
pass
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
@@ -736,7 +736,7 @@ class Interface:
analytics_integration = "WandB"
if self.share_url is not None:
wandb.log({"Gradio panel": wandb.Html('<iframe src="' + self.share_url + '" width="' +
- str(self.width) + '" height="' + str(self.height) + '" frameBorder="0"></iframe>')})
+ str(self.width) + '" height="' + str(self.height) + '" frameBorder="0"></iframe>')})
else:
print(
"The WandB integration requires you to `launch(share=True)` first.")
@@ -749,8 +749,8 @@ class Interface:
mlflow.log_param("Gradio Interface Local Link",
self.local_url)
if self.analytics_enabled and analytics_integration:
- data = {'integration': analytics_integration}
- utils.integration_analytics(data)
+ data = {'integration': analytics_integration}
+ utils.integration_analytics(data)
def close_all(verbose=True):
@@ -760,6 +760,6 @@ def close_all(verbose=True):
def reset_all():
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
"and will be deprecated. Please use `close_all()` instead.")
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
"and will be deprecated. Please use `close_all()` instead.")
close_all()
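# Sketch: close_all() shuts down every tracked Interface, and reset_all() stays as a
# deprecated alias that warns and then delegates to it. Assumes close_all is re-exported
# at the gradio package level:
import gradio as gr

demo = gr.Interface(fn=lambda s: s, inputs="text", outputs="text")
demo.launch()     # in a notebook this returns immediately
gr.close_all()    # frees the port(s) used by running interfaces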