finished adding type hints in interface.py

This commit is contained in:
Abubakar Abid 2021-12-27 12:34:05 -06:00
parent 1758b9e1fc
commit 5a0caeaf7a

View File

@ -17,11 +17,10 @@ import warnings
import webbrowser
import weakref
from gradio import encryptor, networking, queue, strings, utils # type: ignore
from gradio import encryptor, interpretation, networking, queue, strings, utils # type: ignore
from gradio.external import load_interface, load_from_pipeline # type: ignore
from gradio.flagging import FlaggingCallback, CSVLogger # type: ignore
from gradio.inputs import get_input_instance, InputComponent # type: ignore
from gradio.interpretation import run_interpret # type: ignore
from gradio.outputs import get_output_instance, OutputComponent # type: ignore
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
@ -40,7 +39,7 @@ class Interface:
instances: weakref.WeakSet = weakref.WeakSet()
@classmethod
def get_instances(cls):
def get_instances(cls) -> List[Interface]:
"""
:return: list of all current instances.
"""
@ -52,7 +51,7 @@ class Interface:
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
**kwargs):
**kwargs) -> Interface:
"""
Class method to construct an Interface from an external source repository, such as huggingface.
Parameters:
@ -75,7 +74,7 @@ class Interface:
def from_pipeline(
cls,
pipeline: transformers.Pipeline,
**kwargs):
**kwargs) -> Interface:
"""
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline):
@ -381,13 +380,13 @@ class Interface:
self,
raw_input: List[Any]
) -> List[Any]:
return run_interpret(self, raw_input)
return interpretation.run_interpret(self, raw_input)
def run_until_interrupted(
self,
thread: threading.Thread,
path_to_local_server: str
):
) -> None:
try:
while True:
time.sleep(0.5)
@ -399,7 +398,7 @@ class Interface:
if self.enable_queue:
queue.close()
def test_launch(self):
def test_launch(self) -> None:
for predict_fn in self.predict:
print("Test launch: {}()...".format(predict_fn.__name__), end=' ')
@ -537,10 +536,8 @@ class Interface:
# Open a browser tab with the interface.
if inbrowser:
if share:
webbrowser.open(share_url)
else:
webbrowser.open(path_to_local_server)
link = share_url if share else path_to_local_server
webbrowser.open(link)
# Check if running in a Python notebook in which case, display inline
if inline is None:
@ -591,20 +588,26 @@ class Interface:
def close(
self,
verbose: bool = True
):
) -> None:
"""
Closes the Interface that was launched. This will close the server and free the port.
"""
try:
self.server.shutdown()
self.server_thread.join()
print("Closing server running on port: {}".format(self.server_port))
if verbose:
print("Closing server running on port: {}".format(self.server_port))
except AttributeError: # can't close if not running
pass
except OSError: # sometimes OSError is thrown when shutting down
pass
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
def integrate(
self,
comet_ml=None,
wandb=None,
mlflow=None
) -> None:
"""
A catch-all method for integrating with other libraries. Should be run after launch()
Parameters:
@ -643,13 +646,13 @@ class Interface:
utils.integration_analytics(data)
def close_all(verbose=True):
def close_all(verbose: bool = True) -> None:
# Tries to close all running interfaces, but method is a little flaky.
for io in Interface.get_instances():
io.close(verbose)
def reset_all():
def reset_all() -> None:
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
"and will be deprecated. Please use `close_all()` instead.")
close_all()