added typing; deprecated capture_session; fixed test

This commit is contained in:
Abubakar Abid 2021-12-25 08:35:45 -06:00
parent 40adb8537b
commit bd433577b6
2 changed files with 103 additions and 37 deletions

View File

@ -3,10 +3,13 @@ This is the core file in the `gradio` package, and defines the Interface class,
interface using the input and output types.
"""
from __future__ import annotations
import copy
import csv
import getpass
import inspect
import threading
from typing import Callable
import markdown2
import numpy as np
import os
@ -14,6 +17,7 @@ import pkg_resources
import random
import sys
import time
from typing import Union, Any, List, Optional, Tuple, TYPE_CHECKING
import warnings
import webbrowser
import weakref
@ -21,9 +25,13 @@ import weakref
from gradio import networking, strings, utils, encryptor, queue
from gradio.external import load_interface, load_from_pipeline
from gradio.flagging import FlaggingCallback, CSVLogger
from gradio.inputs import get_input_instance
from gradio.inputs import get_input_instance, InputComponent
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.outputs import get_output_instance
from gradio.outputs import get_output_instance, OutputComponent
if TYPE_CHECKING: # Only import for type checking (always False at runtime).
import transformers
import flask
class Interface:
@ -44,7 +52,12 @@ class Interface:
return list(Interface.instances)
@classmethod
def load(cls, name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs):
def load(cls,
name: str,
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
**kwargs):
"""
Class method to construct an Interface from an external source repository, such as huggingface.
Parameters:
@ -64,7 +77,10 @@ class Interface:
return interface
@classmethod
def from_pipeline(cls, pipeline: "transformers.Pipeline", **kwargs):
def from_pipeline(
cls,
pipeline: transformers.Pipeline,
**kwargs):
"""
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline):
@ -76,18 +92,46 @@ class Interface:
interface = cls(**kwargs)
return interface
def __init__(self, fn, inputs=None, outputs=None, verbose=None, examples=None,
examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True,
capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
css=None, height=None, width=None, allow_screenshot=True, allow_flagging=None, flagging_options=None,
encrypt=None, show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None,
flagging_callback=CSVLogger()):
def __init__(
self,
fn: Callable,
inputs: str | InputComponent | List[str | InputComponent] = None,
outputs: str | OutputComponent | List[str | OutputComponent] = None,
verbose: bool = None,
examples: Union[List[List[Any]], str] = None,
examples_per_page: int = 10,
live: bool = False,
layout: str = "unaligned",
show_input: bool = True,
show_output: bool = True,
capture_session: Optional[bool] = None,
interpretation: Optional[bool] = None,
num_shap: int = 2.0,
theme: Optional[str] = None,
repeat_outputs_per_model: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
article: Optional[str] = None,
thumbnail: Optional[str] = None,
css: Optional[str] = None,
height=None,
width=None,
allow_screenshot: bool = True,
allow_flagging: bool = None,
flagging_options: List[str]=None,
encrypt=None,
show_tips=None,
flagging_dir: str = "flagged",
analytics_enabled: Optional[bool] = None,
enable_queue=None,
api_mode=None,
flagging_callback: FlaggingCallback = CSVLogger()
):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
inputs (Union[str, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
outputs (Union[str, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
inputs (Union[str, InputComponent, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
outputs (Union[str, OutputComponent, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
verbose (bool): DEPRECATED. Whether to print detailed information during launch.
examples (Union[List[List[Any]], str]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
examples_per_page (int): If examples are provided, how many to display per page.
@ -343,7 +387,12 @@ class Interface:
config["examples"] = self.examples
return config
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
def run_prediction(
self,
processed_input: List[Any],
return_duration: bool = False,
called_directly: bool = False
) -> List[Any] | Tuple[List[Any], List[float]]:
"""
This is the method that actually runs the prediction function with the given (processed) inputs.
Parameters:
@ -384,7 +433,10 @@ class Interface:
else:
return predictions
def process(self, raw_input):
def process(
self,
raw_input: List[Any]
) -> Tuple[List[Any], List[float]]:
"""
Parameters:
raw_input: a list of raw inputs to process and apply the prediction(s) on.
@ -400,7 +452,10 @@ class Interface:
for i, output_component in enumerate(self.output_components)]
return processed_output, durations
def interpret(self, raw_input):
def interpret(
self,
raw_input: List[Any]
) -> List[Any]:
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
interpretation for a certain set of UI component types, as well as the custom interpretation case.
@ -515,7 +570,11 @@ class Interface:
interpretation = [interpretation]
return interpretation, []
def run_until_interrupted(self, thread, path_to_local_server):
def run_until_interrupted(
self,
thread: threading.Thread,
path_to_local_server: str
):
try:
while True:
time.sleep(0.5)
@ -543,11 +602,25 @@ class Interface:
print("PASSED")
continue
def launch(self, inline=None, inbrowser=None, share=False, debug=False,
auth=None, auth_message=None, private_endpoint=None,
prevent_thread_lock=False, show_error=True, server_name=None,
server_port=None, show_tips=False, enable_queue=False,
height=500, width=900, encrypt=False):
def launch(
self,
inline: bool = None,
inbrowser: bool = None,
share: bool = False,
debug: bool = False,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
private_endpoint: Optional[str] = None,
prevent_thread_lock: bool = False,
show_error: bool = True,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
show_tips: bool = False,
enable_queue: bool = False,
height: int = 500,
width: int = 900,
encrypt: bool = False
) -> Tuple[flask.Flask, str, str]:
"""
Launches the webserver that serves the UI for the interface.
Parameters:
@ -564,13 +637,13 @@ class Interface:
server_name (str): to make app accessible on local network, set this to "0.0.0.0".
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
width (int): The width in pixels of the <iframe> element containing the interface (used if inline=True)
height (int): The height in pixels of the <iframe> element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
Returns:
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
width (bool): The width of the <iframe> element containing the interface (used if inline=True)
height (bool): The height of the <iframe> element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
"""
# Set up local flask server
config = self.get_config_file()
@ -702,7 +775,10 @@ class Interface:
return app, path_to_local_server, share_url
def close(self, verbose=True):
def close(
self,
verbose: bool = True
):
"""
Closes the Interface that was launched. This will close the server and free the port.
"""

View File

@ -189,15 +189,5 @@ class TestInterface(unittest.TestCase):
interface.integrate(mlflow=mlflow)
mock_post.assert_called_once()
def test_capture_session(self):
interface = Interface(lambda x: x, "textbox", "label", capture_session=True, interpretation=lambda x: 0)
interface.session = (mock.MagicMock(), mock.MagicMock())
interface.interpret(["quickest brown fox"])
interface.session[0].as_default.assert_called_once()
interface.session[1].as_default.assert_called_once()
interface.run_prediction(["quickest brown fox"])
self.assertEqual(interface.session[0].as_default.call_count, 2)
self.assertEqual(interface.session[1].as_default.call_count, 2)
if __name__ == '__main__':
unittest.main()