mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Introduce example type hints
This commit is contained in:
parent
d9f302ab82
commit
831e758fdb
@ -40,7 +40,7 @@ class Interface:
|
||||
return list(Interface.instances)
|
||||
|
||||
@classmethod
|
||||
def load(cls, name, src=None, api_key=None, alias=None, **kwargs):
|
||||
def load(cls, name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs):
|
||||
"""
|
||||
Class method to construct an Interface from an external source repository, such as huggingface.
|
||||
Parameters:
|
||||
@ -52,15 +52,15 @@ class Interface:
|
||||
(gradio.Interface): a Gradio Interface object for the given model
|
||||
"""
|
||||
interface_info = load_interface(name, src, api_key, alias)
|
||||
# create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
|
||||
# and that can cause some issues since the internal prediction function may rely on the original interface_info dict
|
||||
kwargs = dict(interface_info, **kwargs)
|
||||
# create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
|
||||
# and that can cause some issues since the internal prediction function may rely on the original interface_info dict
|
||||
kwargs = dict(interface_info, **kwargs)
|
||||
interface = cls(**kwargs)
|
||||
interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess
|
||||
return interface
|
||||
|
||||
@classmethod
|
||||
def from_pipeline(cls, pipeline, **kwargs):
|
||||
def from_pipeline(cls, pipeline: "transformers.Pipeline", **kwargs):
|
||||
"""
|
||||
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline):
|
||||
@ -68,7 +68,7 @@ class Interface:
|
||||
(gradio.Interface): a Gradio Interface object from the given Pipeline
|
||||
"""
|
||||
interface_info = load_from_pipeline(pipeline)
|
||||
kwargs = dict(interface_info, **kwargs)
|
||||
kwargs = dict(interface_info, **kwargs)
|
||||
interface = cls(**kwargs)
|
||||
return interface
|
||||
|
||||
@ -76,7 +76,7 @@ class Interface:
|
||||
examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True,
|
||||
capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None,
|
||||
css=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None,
|
||||
encrypt=False, show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None):
|
||||
"""
|
||||
Parameters:
|
||||
@ -129,7 +129,7 @@ class Interface:
|
||||
self.predict_durations = [[0, 0]] * len(fn)
|
||||
self.function_names = [func.__name__ for func in fn]
|
||||
self.__name__ = ", ".join(self.function_names)
|
||||
|
||||
|
||||
if verbose is not None:
|
||||
warnings.warn("The `verbose` parameter in the `Interface` is deprecated and has no effect.")
|
||||
|
||||
@ -140,7 +140,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
|
||||
|
||||
if capture_session is not None:
|
||||
warnings.warn("The `capture_session` parameter in the `Interface` will be deprecated in the near future.")
|
||||
|
||||
@ -174,12 +174,12 @@ class Interface:
|
||||
"Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
|
||||
self.num_shap = num_shap
|
||||
self.examples_per_page = examples_per_page
|
||||
|
||||
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
|
||||
self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True"
|
||||
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
|
||||
self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True") == "True"
|
||||
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_dir = flagging_dir
|
||||
self.encrypt = encrypt
|
||||
@ -188,13 +188,13 @@ class Interface:
|
||||
self.share_url = None
|
||||
self.local_url = None
|
||||
self.ip_address = networking.get_local_ip_address()
|
||||
|
||||
|
||||
if show_tips is not None:
|
||||
warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead")
|
||||
|
||||
self.requires_permissions = any(
|
||||
[component.requires_permissions for component in self.input_components])
|
||||
|
||||
|
||||
self.enable_queue = enable_queue
|
||||
if self.enable_queue is not None:
|
||||
warnings.warn("The `enable_queue` parameter in the `Interface` will be deprecated. Please use the `enable_queue` parameter in `launch()` instead")
|
||||
@ -207,7 +207,7 @@ class Interface:
|
||||
try:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
tf.keras.backend.get_session()
|
||||
except (ImportError, AttributeError):
|
||||
# If they are using TF >= 2.0 or don't have TF,
|
||||
# just ignore this.
|
||||
@ -240,7 +240,7 @@ class Interface:
|
||||
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
|
||||
output = self.run_prediction(params, called_directly=True)
|
||||
else:
|
||||
output, _ = self.process(params)
|
||||
output, _ = self.process(params)
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
def __str__(self):
|
||||
@ -249,7 +249,7 @@ class Interface:
|
||||
def __repr__(self):
|
||||
repr = "Gradio Interface for: {}".format(
|
||||
", ".join(fn.__name__ for fn in self.predict))
|
||||
repr += "\n" + "-"*len(repr)
|
||||
repr += "\n" + "-" * len(repr)
|
||||
repr += "\ninputs:"
|
||||
for component in self.input_components:
|
||||
repr += "\n|-{}".format(str(component))
|
||||
@ -296,7 +296,7 @@ class Interface:
|
||||
function_index = i // outputs_per_function
|
||||
component_index = i - function_index * outputs_per_function
|
||||
ret_name = "Output " + \
|
||||
str(component_index + 1) if outputs_per_function > 1 else "Output"
|
||||
str(component_index + 1) if outputs_per_function > 1 else "Output"
|
||||
if iface["label"] is None:
|
||||
iface["label"] = ret_name
|
||||
if len(self.predict) > 1:
|
||||
@ -344,7 +344,7 @@ class Interface:
|
||||
"""
|
||||
if self.api_mode: # Serialize the input
|
||||
processed_input = [input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
predictions = []
|
||||
durations = []
|
||||
output_component_counter = 0
|
||||
@ -395,7 +395,7 @@ class Interface:
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
predictions, durations = self.run_prediction(
|
||||
processed_input, return_duration=True)
|
||||
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
|
||||
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
|
||||
for i, output_component in enumerate(self.output_components)]
|
||||
return processed_output, durations
|
||||
|
||||
@ -467,7 +467,7 @@ class Interface:
|
||||
raise ValueError(
|
||||
"The package `shap` is required for this interpretation method. Try: `pip install shap`")
|
||||
input_component = self.input_components[i]
|
||||
if not(input_component.interpret_by_tokens):
|
||||
if not (input_component.interpret_by_tokens):
|
||||
raise ValueError(
|
||||
"Input component {} does not support `shap` interpretation".format(input_component))
|
||||
|
||||
@ -494,7 +494,7 @@ class Interface:
|
||||
explainer = shap.KernelExplainer(
|
||||
get_masked_prediction, np.zeros((1, num_total_segments)))
|
||||
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(
|
||||
self.num_shap*num_total_segments), silent=True)
|
||||
self.num_shap * num_total_segments), silent=True)
|
||||
scores.append(input_component.get_interpretation_scores(
|
||||
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
|
||||
alternative_outputs.append([])
|
||||
@ -613,7 +613,7 @@ class Interface:
|
||||
|
||||
# If running in a colab or not able to access localhost, automatically create a shareable link
|
||||
is_colab = utils.colab_check()
|
||||
if is_colab or not(networking.url_ok(path_to_local_server)):
|
||||
if is_colab or not (networking.url_ok(path_to_local_server)):
|
||||
share = True
|
||||
if is_colab:
|
||||
if debug:
|
||||
@ -668,7 +668,7 @@ class Interface:
|
||||
display(IFrame(share_url, width=self.width, height=self.height))
|
||||
else:
|
||||
display(IFrame(path_to_local_server,
|
||||
width=self.width, height=self.height))
|
||||
width=self.width, height=self.height))
|
||||
except ImportError:
|
||||
pass # IPython is not available so does not print inline.
|
||||
|
||||
@ -709,9 +709,9 @@ class Interface:
|
||||
self.server.shutdown()
|
||||
self.server_thread.join()
|
||||
print("Closing server running on port: {}".format(self.server_port))
|
||||
except AttributeError: # can't close if not running
|
||||
except AttributeError: # can't close if not running
|
||||
pass
|
||||
except OSError: # sometimes OSError is thrown when shutting down
|
||||
except OSError: # sometimes OSError is thrown when shutting down
|
||||
pass
|
||||
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
|
||||
@ -736,7 +736,7 @@ class Interface:
|
||||
analytics_integration = "WandB"
|
||||
if self.share_url is not None:
|
||||
wandb.log({"Gradio panel": wandb.Html('<iframe src="' + self.share_url + '" width="' +
|
||||
str(self.width) + '" height="' + str(self.height) + '" frameBorder="0"></iframe>')})
|
||||
str(self.width) + '" height="' + str(self.height) + '" frameBorder="0"></iframe>')})
|
||||
else:
|
||||
print(
|
||||
"The WandB integration requires you to `launch(share=True)` first.")
|
||||
@ -749,8 +749,8 @@ class Interface:
|
||||
mlflow.log_param("Gradio Interface Local Link",
|
||||
self.local_url)
|
||||
if self.analytics_enabled and analytics_integration:
|
||||
data = {'integration': analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
data = {'integration': analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
|
||||
|
||||
def close_all(verbose=True):
|
||||
@ -760,6 +760,6 @@ def close_all(verbose=True):
|
||||
|
||||
|
||||
def reset_all():
|
||||
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
|
||||
"and will be deprecated. Please use `close_all()` instead.")
|
||||
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
|
||||
"and will be deprecated. Please use `close_all()` instead.")
|
||||
close_all()
|
||||
|
Loading…
Reference in New Issue
Block a user