More typing! (#2906)

* started pathlib

* blocks.py

* more changes

* fixes

* typing

* formatting

* typing

* renaming files

* changelog

* script

* changelog

* lint

* routes

* renamed

* state

* formatting

* state

* type check script

* remove strictness

* switched to pyright

* switched to pyright

* fixed flaky tests

* fixed test xray

* fixed load test

* fixed blocks tests

* formatting

* fixed components test

* uncomment tests

* fixed interpretation tests

* formatting

* last tests hopefully

* argh lint

* component

* fixed based on review

* refactor

* components.py t yping

* components.py

* formatting

* lint script

* merge

* merge

* lint

* pathlib

* lint

* events too

* lint script

* fixing tests

* lint

* examples

* serializing

* more files

* formatting

* flagging.py

* added to lint script

* fixed tab

* interface.py

* attempt fix

* refactoring interface

* interface refactor

* formatting

* fix for live interfaces

* lint

* mix

* mix

* serialize fix

* formatting

* all demos queue

* networking

* added type check

* processing_utils

* more typing

* formatting

* type ignored processing utils

* s

* tunneling

* add interpretation

* more typing

* queuing

* serializing

* undo interpretation

* routes.py

* formatting

* component type

* addressed review

* lint

* typing

* documentation

* fixing pydantic

* routes

* fixed typing in routes

* fix tests
This commit is contained in:
Abubakar Abid 2023-01-03 13:13:11 -05:00 committed by GitHub
parent 6a6e9175e1
commit 38c64a5b0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 196 additions and 166 deletions

View File

@ -1435,6 +1435,7 @@ class Image(
assert isinstance(x, dict)
x, mask = x["image"], x["mask"]
assert isinstance(x, str)
im = processing_utils.decode_base64_to_image(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -1,3 +1,6 @@
"""Pydantic data models and other dataclasses. This is the only file that uses Optional[]
typing syntax instead of | None syntax to work with pydantic"""
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
@ -35,8 +38,8 @@ class Estimation(BaseModel):
queue_size: int
avg_event_process_time: Optional[float]
avg_event_concurrent_process_time: Optional[float]
rank_eta: Optional[int] = None
queue_eta: int
rank_eta: Optional[float] = None
queue_eta: float
class ProgressUnit(BaseModel):

View File

@ -1,5 +1,9 @@
"""Contains methods that generate documentation for Gradio functions and classes."""
from __future__ import annotations
import inspect
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
classes_to_document = {}
documentation_group = None
@ -30,7 +34,7 @@ def document(*fns):
return inner_doc
def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, Optional[str]]:
def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, str | None]:
"""
Generates documentation for any function.
Parameters:

View File

@ -3,16 +3,12 @@ Ways to transform interfaces to produce new interfaces
"""
import asyncio
import warnings
from typing import TYPE_CHECKING, List
import gradio
from gradio.documentation import document, set_documentation_group
set_documentation_group("mix_interface")
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import IOComponent
@document()
class Parallel(gradio.Interface):
@ -32,7 +28,7 @@ class Parallel(gradio.Interface):
Returns:
an Interface object comparing the given models
"""
outputs: List[IOComponent] = []
outputs = []
for interface in interfaces:
if not (isinstance(interface, gradio.Interface)):
@ -44,7 +40,7 @@ class Parallel(gradio.Interface):
async def parallel_fn(*args):
return_values_with_durations = await asyncio.gather(
*[interface.call_function(0, args) for interface in interfaces]
*[interface.call_function(0, list(args)) for interface in interfaces]
)
return_values = [rv["prediction"] for rv in return_values_with_durations]
combined_list = []
@ -97,7 +93,7 @@ class Series(gradio.Interface):
]
# run all of predictions sequentially
data = (await interface.call_function(0, data))["prediction"]
data = (await interface.call_function(0, list(data)))["prediction"]
if len(interface.output_components) == 1:
data = [data]
@ -110,7 +106,7 @@ class Series(gradio.Interface):
)
]
if len(interface.output_components) == 1:
if len(interface.output_components) == 1: # type: ignore
return data[0]
return data

View File

@ -11,7 +11,6 @@ import time
import warnings
from typing import TYPE_CHECKING, Tuple
import fastapi
import requests
import uvicorn
@ -69,7 +68,7 @@ def get_first_available_port(initial: int, final: int) -> int:
)
def configure_app(app: fastapi.FastAPI, blocks: Blocks) -> fastapi.FastAPI:
def configure_app(app: App, blocks: Blocks) -> App:
auth = blocks.auth
if auth is not None:
if not callable(auth):
@ -183,3 +182,4 @@ def url_ok(url: str) -> bool:
time.sleep(0.500)
except (ConnectionError, requests.exceptions.ConnectionError):
return False
return False

View File

@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Dict
from gradio import components
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
from transformers import pipelines
def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
@ -20,17 +20,18 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"""
try:
import transformers
from transformers import pipelines
except ImportError:
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
)
if not isinstance(pipeline, transformers.Pipeline):
if not isinstance(pipeline, pipelines.base.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")
# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
# version of the transformers library that the user has installed.
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
pipeline, transformers.AudioClassificationPipeline
pipeline, pipelines.audio_classification.AudioClassificationPipeline
):
pipeline_info = {
"inputs": components.Audio(
@ -41,7 +42,8 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
pipeline, transformers.AutomaticSpeechRecognitionPipeline
pipeline,
pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
):
pipeline_info = {
"inputs": components.Audio(
@ -52,7 +54,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r["text"],
}
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
pipeline, transformers.FeatureExtractionPipeline
pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -61,7 +63,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r[0],
}
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
pipeline, transformers.FillMaskPipeline
pipeline, pipelines.fill_mask.FillMaskPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -70,7 +72,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
}
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
pipeline, transformers.ImageClassificationPipeline
pipeline, pipelines.image_classification.ImageClassificationPipeline
):
pipeline_info = {
"inputs": components.Image(type="filepath", label="Input Image"),
@ -79,7 +81,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
pipeline, transformers.QuestionAnsweringPipeline
pipeline, pipelines.question_answering.QuestionAnsweringPipeline
):
pipeline_info = {
"inputs": [
@ -94,7 +96,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: (r["answer"], r["score"]),
}
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
pipeline, transformers.SummarizationPipeline
pipeline, pipelines.text2text_generation.SummarizationPipeline
):
pipeline_info = {
"inputs": components.Textbox(lines=7, label="Input"),
@ -103,7 +105,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r[0]["summary_text"],
}
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
pipeline, transformers.TextClassificationPipeline
pipeline, pipelines.text_classification.TextClassificationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -112,7 +114,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
pipeline, transformers.TextGenerationPipeline
pipeline, pipelines.text_generation.TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -121,7 +123,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "TranslationPipeline") and isinstance(
pipeline, transformers.TranslationPipeline
pipeline, pipelines.text2text_generation.TranslationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -130,7 +132,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r[0]["translation_text"],
}
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
pipeline, transformers.Text2TextGenerationPipeline
pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
):
pipeline_info = {
"inputs": components.Textbox(label="Input"),
@ -139,7 +141,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
pipeline, transformers.ZeroShotClassificationPipeline
pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
):
pipeline_info = {
"inputs": [
@ -167,9 +169,9 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
if isinstance(
pipeline,
(
transformers.TextClassificationPipeline,
transformers.Text2TextGenerationPipeline,
transformers.TranslationPipeline,
pipelines.text_classification.TextClassificationPipeline,
pipelines.text2text_generation.Text2TextGenerationPipeline,
pipelines.text2text_generation.TranslationPipeline,
),
):
data = pipeline(*data)

View File

@ -34,11 +34,14 @@ with warnings.catch_warnings():
def to_binary(x: str | Dict) -> bytes:
"""Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
if isinstance(x, dict) and not x.get("data"):
x = encode_url_or_file_to_base64(x["name"])
elif isinstance(x, dict) and x.get("data"):
x = x["data"]
return base64.b64decode(x.split(",")[1])
if isinstance(x, dict):
if x.get("data"):
base64str = x["data"]
else:
base64str = encode_url_or_file_to_base64(x["name"])
else:
base64str = x
return base64.b64decode(base64str.split(",")[1])
#########################
@ -46,31 +49,33 @@ def to_binary(x: str | Dict) -> bytes:
#########################
def decode_base64_to_image(encoding):
def decode_base64_to_image(encoding: str) -> Image.Image:
content = encoding.split(";")[1]
image_encoded = content.split(",")[1]
return Image.open(BytesIO(base64.b64decode(image_encoded)))
def encode_url_or_file_to_base64(path, encryption_key=None):
if utils.validate_url(path):
return encode_url_to_base64(path, encryption_key=encryption_key)
def encode_url_or_file_to_base64(path: str | Path, encryption_key: bytes | None = None):
if utils.validate_url(str(path)):
return encode_url_to_base64(str(path), encryption_key=encryption_key)
else:
return encode_file_to_base64(path, encryption_key=encryption_key)
return encode_file_to_base64(str(path), encryption_key=encryption_key)
def get_mimetype(filename):
def get_mimetype(filename: str) -> str | None:
mimetype = mimetypes.guess_type(filename)[0]
if mimetype is not None:
mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
return mimetype
def get_extension(encoding):
def get_extension(encoding: str) -> str | None:
encoding = encoding.replace("audio/wav", "audio/x-wav")
type = mimetypes.guess_type(encoding)[0]
if type == "audio/flac": # flac is not supported by mimetypes
return "flac"
elif type is None:
return None
extension = mimetypes.guess_extension(type)
if extension is not None and extension.startswith("."):
extension = extension[1:]
@ -176,7 +181,7 @@ def resize_and_crop(img, size, crop_type="center"):
resize[0] = img.size[0]
if size[1] is None:
resize[1] = img.size[1]
return ImageOps.fit(img, resize, centering=center)
return ImageOps.fit(img, resize, centering=center) # type: ignore
##################
@ -188,7 +193,7 @@ def audio_from_file(filename, crop_min=0, crop_max=100):
try:
audio = AudioSegment.from_file(filename)
except FileNotFoundError as e:
isfile = os.path.isfile(filename)
isfile = Path(filename).is_file()
msg = (
f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
+ " Please install `ffmpeg` in your system to use non-WAV audio file formats"
@ -215,7 +220,8 @@ def audio_to_file(sample_rate, data, filename):
sample_width=data.dtype.itemsize,
channels=(1 if len(data.shape) == 1 else data.shape[1]),
)
audio.export(filename, format="wav").close()
file = audio.export(filename, format="wav")
file.close() # type: ignore
def convert_to_16_bit_wav(data):
@ -266,7 +272,7 @@ def decode_base64_to_file(
os.makedirs(dir, exist_ok=True)
data, extension = decode_base64_to_binary(encoding)
if file_path is not None and prefix is None:
filename = os.path.basename(file_path)
filename = Path(file_path).name
prefix = filename
if "." in filename:
prefix = filename[0 : filename.index(".")]
@ -341,7 +347,7 @@ class TempFileManager:
return sha1.hexdigest()
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = os.path.basename(file_path_or_url)
file_name = Path(file_path_or_url).name
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0 : file_name.index(".")]
@ -365,13 +371,13 @@ class TempFileManager:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
f = tempfile.NamedTemporaryFile()
temp_dir, _ = os.path.split(f.name)
temp_dir = Path(f.name).parent
temp_file_path = self.get_temp_file_path(file_path)
f.name = os.path.join(temp_dir, temp_file_path)
full_temp_file_path = os.path.abspath(f.name)
f.name = str(temp_dir / temp_file_path)
full_temp_file_path = str(Path(f.name).resolve())
if not os.path.exists(full_temp_file_path):
if not Path(full_temp_file_path).exists():
shutil.copy2(file_path, full_temp_file_path)
self.temp_files.add(full_temp_file_path)
@ -381,13 +387,13 @@ class TempFileManager:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
f = tempfile.NamedTemporaryFile()
temp_dir, _ = os.path.split(f.name)
temp_dir = Path(f.name).parent
temp_file_path = self.get_temp_url_path(url)
f.name = os.path.join(temp_dir, temp_file_path)
full_temp_file_path = os.path.abspath(f.name)
f.name = str(temp_dir / temp_file_path)
full_temp_file_path = str(Path(f.name).resolve())
if not os.path.exists(full_temp_file_path):
if not Path(full_temp_file_path).exists():
with requests.get(url, stream=True) as r:
with open(full_temp_file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
@ -399,7 +405,7 @@ class TempFileManager:
def create_tmp_copy_of_file(file_path, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
file_name = os.path.basename(file_path)
file_name = Path(file_path).name
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0 : file_name.index(".")]
@ -602,8 +608,8 @@ def _convert(image, dtype, force_copy=False, uniform=False):
imin_in = np.iinfo(dtype_in).min
imax_in = np.iinfo(dtype_in).max
if kind_out in "ui":
imin_out = np.iinfo(dtype_out).min
imax_out = np.iinfo(dtype_out).max
imin_out = np.iinfo(dtype_out).min # type: ignore
imax_out = np.iinfo(dtype_out).max # type: ignore
# any -> binary
if kind_out == "b":
@ -632,23 +638,23 @@ def _convert(image, dtype, force_copy=False, uniform=False):
if not uniform:
if kind_out == "u":
image_out = np.multiply(image, imax_out, dtype=computation_type)
image_out = np.multiply(image, imax_out, dtype=computation_type) # type: ignore
else:
image_out = np.multiply(
image, (imax_out - imin_out) / 2, dtype=computation_type
image, (imax_out - imin_out) / 2, dtype=computation_type # type: ignore
)
image_out -= 1.0 / 2.0
np.rint(image_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
elif kind_out == "u":
image_out = np.multiply(image, imax_out + 1, dtype=computation_type)
np.clip(image_out, 0, imax_out, out=image_out)
image_out = np.multiply(image, imax_out + 1, dtype=computation_type) # type: ignore
np.clip(image_out, 0, imax_out, out=image_out) # type: ignore
else:
image_out = np.multiply(
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type # type: ignore
)
np.floor(image_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
return image_out.astype(dtype_out)
# signed/unsigned int -> float
@ -661,13 +667,13 @@ def _convert(image, dtype, force_copy=False, uniform=False):
if kind_in == "u":
# using np.divide or np.multiply doesn't copy the data
# until the computation time
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type)
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type) # type: ignore
# DirectX uses this conversion also for signed ints
# if imin_in:
# np.maximum(image, -1.0, out=image)
else:
image = np.add(image, 0.5, dtype=computation_type)
image *= 2 / (imax_in - imin_in)
image *= 2 / (imax_in - imin_in) # type: ignore
return np.asarray(image, dtype_out)
@ -693,9 +699,9 @@ def _convert(image, dtype, force_copy=False, uniform=False):
return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)
image = image.astype(_dtype_bits("i", itemsize_out * 8))
image -= imin_in
image -= imin_in # type: ignore
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
image += imin_out
image += imin_out # type: ignore
return image.astype(dtype_out)

View File

@ -5,7 +5,7 @@ import copy
import sys
import time
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Tuple
from typing import Any, Deque, Dict, List, Tuple
import fastapi
@ -31,7 +31,7 @@ class Event:
self.progress: Progress | None = None
self.progress_pending: bool = False
async def disconnect(self, code=1000):
async def disconnect(self, code: int = 1000):
await self.websocket.close(code=code)
@ -41,7 +41,7 @@ class Queue:
live_updates: bool,
concurrency_count: int,
update_intervals: float,
max_size: Optional[int],
max_size: int | None,
blocks_dependencies: List,
):
self.event_queue: Deque[Event] = deque()
@ -54,7 +54,7 @@ class Queue:
self.server_path = None
self.duration_history_total = 0
self.duration_history_count = 0
self.avg_process_time = None
self.avg_process_time = 0
self.avg_concurrent_process_time = None
self.queue_duration = 1
self.live_updates = live_updates
@ -139,7 +139,7 @@ class Queue:
if job is None:
continue
for event in job:
if event.progress_pending:
if event.progress_pending and event.progress:
event.progress_pending = False
client_awake = await self.send_message(
event, event.progress.dict()
@ -290,11 +290,12 @@ class Queue:
"headers": dict(websocket.headers),
"query_params": dict(websocket.query_params),
"path_params": dict(websocket.path_params),
"client": dict(host=websocket.client.host, port=websocket.client.port),
"client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore
}
async def call_prediction(self, events: List[Event], batch: bool):
data = events[0].data
assert data is not None, "No event data"
token = events[0].token
data.event_id = events[0]._id if not batch else None
try:
@ -346,6 +347,7 @@ class Queue:
},
)
elif response.json.get("is_generating", False):
old_response = response
while response.json.get("is_generating", False):
# Python 3.7 doesn't have named tasks.
# In order to determine if a task was cancelled, we
@ -425,7 +427,7 @@ class Queue:
await self.clean_event(event)
return False
async def get_message(self, event) -> Optional[PredictBody]:
async def get_message(self, event) -> PredictBody | None:
try:
data = await event.websocket.receive_json()
return PredictBody(**data)

View File

@ -8,6 +8,7 @@ $ gradio app.py my_demo, to use variable names other than "demo"
import inspect
import os
import sys
from pathlib import Path
import gradio
from gradio import networking
@ -23,13 +24,13 @@ def run_in_reload_mode():
demo_name = args[1]
original_path = args[0]
abs_original_path = os.path.abspath(original_path)
path = os.path.normpath(original_path)
abs_original_path = Path(original_path).name
path = str(Path(original_path).resolve())
path = path.replace("/", ".")
path = path.replace("\\", ".")
filename = os.path.splitext(path)[0]
filename = Path(path).stem
gradio_folder = os.path.dirname(inspect.getfile(gradio))
gradio_folder = Path(inspect.getfile(gradio)).parent
port = networking.get_first_available_port(
networking.INITIAL_PORT_VALUE,
@ -42,16 +43,17 @@ def run_in_reload_mode():
message = "Watching:"
message_change_count = 0
if gradio_folder.strip():
if str(gradio_folder).strip():
command += f'--reload-dir "{gradio_folder}" '
message += f" '{gradio_folder}'"
message_change_count += 1
if os.path.dirname(abs_original_path).strip():
command += f'--reload-dir "{os.path.dirname(abs_original_path)}"'
abs_parent = Path(abs_original_path).parent
if str(abs_parent).strip():
command += f'--reload-dir "{abs_parent}"'
if message_change_count == 1:
message += ","
message += f" '{os.path.dirname(abs_original_path)}'"
message += f" '{abs_parent}'"
print(message + "\n")
os.system(command)

View File

@ -1,10 +1,10 @@
"""Implements a FastAPI server to run the gradio interface."""
"""Implements a FastAPI server to run the gradio interface. Note that some types in this
module use the Optional/Union notation so that they work correctly with pydantic."""
from __future__ import annotations
import asyncio
import inspect
import io
import json
import mimetypes
import os
@ -36,7 +36,7 @@ from starlette.responses import RedirectResponse
from starlette.websockets import WebSocketState
import gradio
from gradio import encryptor, utils
from gradio import utils
from gradio.data_classes import PredictBody, ResetBody
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
@ -97,9 +97,9 @@ class App(FastAPI):
"""
def __init__(self, **kwargs):
self.tokens = None
self.tokens = {}
self.auth = None
self.blocks: Optional[gradio.Blocks] = None
self.blocks: gradio.Blocks | None = None
self.state_holder = {}
self.iterators = defaultdict(dict)
self.lock = asyncio.Lock()
@ -124,6 +124,11 @@ class App(FastAPI):
self.favicon_path = blocks.favicon_path
self.tokens = {}
def get_blocks(self) -> gradio.Blocks:
if self.blocks is None:
raise ValueError("No Blocks has been configured for this app.")
return self.blocks
@staticmethod
def create_app(blocks: gradio.Blocks) -> App:
app = App(default_response_class=ORJSONResponse)
@ -151,7 +156,7 @@ class App(FastAPI):
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)
async def ws_login_check(websocket: WebSocket) -> str:
async def ws_login_check(websocket: WebSocket) -> Optional[str]:
token = websocket.cookies.get("access-token")
return token # token is returned to allow request in queue
@ -163,13 +168,15 @@ class App(FastAPI):
@app.get("/app_id")
@app.get("/app_id/")
def app_id(request: fastapi.Request) -> int:
return {"app_id": app.blocks.app_id}
def app_id(request: fastapi.Request) -> dict:
return {"app_id": app.get_blocks().app_id}
@app.post("/login")
@app.post("/login/")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if app.auth is None:
return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
if (
not callable(app.auth)
and username in app.auth
@ -191,24 +198,25 @@ class App(FastAPI):
@app.get("/", response_class=HTMLResponse)
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
if app.auth is None or not (user is None):
config = app.blocks.config
config = app.get_blocks().config
else:
config = {
"auth_required": True,
"auth_message": app.blocks.auth_message,
"auth_message": blocks.auth_message,
}
try:
template = (
"frontend/share.html" if app.blocks.share else "frontend/index.html"
"frontend/share.html" if blocks.share else "frontend/index.html"
)
return templates.TemplateResponse(
template, {"request": request, "config": config}
)
except TemplateNotFound:
if app.blocks.share:
if blocks.share:
raise ValueError(
"Did you install Gradio from source files? Share mode only "
"works when Gradio is installed through the pip package."
@ -222,7 +230,7 @@ class App(FastAPI):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config():
return app.blocks.config
return app.get_blocks().config
@app.get("/static/{path:path}")
def static_resource(path: str):
@ -240,31 +248,20 @@ class App(FastAPI):
@app.get("/favicon.ico")
async def favicon():
if app.blocks.favicon_path is None:
blocks = app.get_blocks()
if blocks.favicon_path is None:
return static_resource("img/logo.svg")
else:
return FileResponse(app.blocks.favicon_path)
return FileResponse(blocks.favicon_path)
@app.get("/file={path:path}", dependencies=[Depends(login_check)])
def file(path: str):
blocks = app.get_blocks()
if utils.validate_url(path):
return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND)
if (
app.blocks.encrypt
and isinstance(app.blocks.examples, str)
and path.startswith(app.blocks.examples)
):
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
encrypted_data = encrypted_file.read()
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
if Path(app.cwd).resolve() in Path(
if Path(app.cwd).resolve() in Path(path).resolve().parents or Path(
path
).resolve().parents or os.path.abspath(path) in set().union(
*app.blocks.temp_file_sets
): # Need to use os.path.abspath in the second condition to be consistent with usage in TempFileManager
).resolve() in set().union(*blocks.temp_file_sets):
return FileResponse(
Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
)
@ -289,14 +286,15 @@ class App(FastAPI):
async def run_predict(
body: PredictBody,
request: Request,
request: Request | List[Request],
fn_index_inferred: int,
username: str = Depends(get_current_user),
):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
_id: deepcopy(getattr(block, "value", None))
for _id, block in app.blocks.blocks.items()
for _id, block in app.get_blocks().blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
@ -315,12 +313,12 @@ class App(FastAPI):
event_id = getattr(body, "event_id", None)
raw_input = body.data
fn_index = body.fn_index
batch = app.blocks.dependencies[fn_index]["batch"]
batch = app.get_blocks().dependencies[fn_index_inferred]["batch"]
if not (body.batched) and batch:
raw_input = [raw_input]
try:
output = await app.blocks.process_api(
fn_index=fn_index,
output = await app.get_blocks().process_api(
fn_index=fn_index_inferred,
inputs=raw_input,
request=request,
state=session_state,
@ -336,7 +334,7 @@ class App(FastAPI):
if isinstance(output, Error):
raise output
except BaseException as error:
show_error = app.blocks.show_error or isinstance(error, Error)
show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
return JSONResponse(
content={"error": str(error) if show_error else None},
@ -358,20 +356,23 @@ class App(FastAPI):
request: fastapi.Request,
username: str = Depends(get_current_user),
):
fn_index_inferred = None
if body.fn_index is None:
for i, fn in enumerate(app.blocks.dependencies):
for i, fn in enumerate(app.get_blocks().dependencies):
if fn["api_name"] == api_name:
body.fn_index = i
fn_index_inferred = i
break
if body.fn_index is None:
if fn_index_inferred is None:
return JSONResponse(
content={
"error": f"This app has no endpoint /api/{api_name}/."
},
status_code=500,
)
if not app.blocks.api_open and app.blocks.queue_enabled_for_fn(
body.fn_index
else:
fn_index_inferred = body.fn_index
if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
fn_index_inferred
):
if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
raise HTTPException(
@ -381,29 +382,36 @@ class App(FastAPI):
# If this fn_index cancels jobs, then the only input we need is the
# current session hash
if app.blocks.dependencies[body.fn_index]["cancels"]:
if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
body.data = [body.session_hash]
if body.request:
if body.batched:
request = [Request(**req) for req in body.request]
gr_request = [Request(**req) for req in body.request]
else:
request = Request(**body.request)
assert isinstance(body.request, dict)
gr_request = Request(**body.request)
else:
request = Request(request)
result = await run_predict(body=body, username=username, request=request)
gr_request = Request(request)
result = await run_predict(
body=body,
fn_index_inferred=fn_index_inferred,
username=username,
request=gr_request,
)
return result
@app.websocket("/queue/join")
async def join_queue(
websocket: WebSocket,
token: str = Depends(ws_login_check),
token: Optional[str] = Depends(ws_login_check),
):
blocks = app.get_blocks()
if app.auth is not None and token is None:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
if app.blocks._queue.server_path is None:
if blocks._queue.server_path is None:
app_url = get_server_url_from_ws_url(str(websocket.url))
app.blocks._queue.set_url(app_url)
blocks._queue.set_url(app_url)
await websocket.accept()
# In order to cancel jobs, we need the session_hash and fn_index
# to create a unique id for each job
@ -414,27 +422,26 @@ class App(FastAPI):
)
# set the token into Event to allow using the same token for call_prediction
event.token = token
event.session_hash = session_info["session_hash"]
# Continuous events are not put in the queue so that they do not
# occupy the queue's resource as they are expected to run forever
if app.blocks.dependencies[event.fn_index].get("every", 0):
await cancel_tasks([f"{event.session_hash}_{event.fn_index}"])
await app.blocks._queue.reset_iterators(
event.session_hash, event.fn_index
)
if blocks.dependencies[event.fn_index].get("every", 0):
await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
task = run_coro_in_background(
app.blocks._queue.process_events, [event], False
blocks._queue.process_events, [event], False
)
set_task_name(task, event.session_hash, event.fn_index, batch=False)
else:
rank = app.blocks._queue.push(event)
rank = blocks._queue.push(event)
if rank is None:
await app.blocks._queue.send_message(event, {"msg": "queue_full"})
await blocks._queue.send_message(event, {"msg": "queue_full"})
await event.disconnect()
return
estimation = app.blocks._queue.get_estimation()
await app.blocks._queue.send_estimation(event, estimation, rank)
estimation = blocks._queue.get_estimation()
await blocks._queue.send_estimation(event, estimation, rank)
while True:
await asyncio.sleep(60)
if websocket.application_state == WebSocketState.DISCONNECTED:
@ -446,19 +453,19 @@ class App(FastAPI):
response_model=Estimation,
)
async def get_queue_status():
return app.blocks._queue.get_estimation()
return app.get_blocks()._queue.get_estimation()
@app.get("/startup-events")
async def startup_events():
if not app.startup_events_triggered:
app.blocks.startup_events()
app.get_blocks().startup_events()
app.startup_events_triggered = True
return True
return False
@app.get("/robots.txt", response_class=PlainTextResponse)
def robots_txt():
if app.blocks.share:
if app.get_blocks().share:
return "User-agent: *\nDisallow: /"
else:
return "User-agent: *\nDisallow: "
@ -471,7 +478,7 @@ class App(FastAPI):
########
def safe_join(directory: str, path: str) -> Optional[str]:
def safe_join(directory: str, path: str) -> str | None:
"""Safely path to a base directory to avoid escaping the base directory.
Borrowed from: werkzeug.security.safe_join"""
_os_alt_seps: List[str] = list(
@ -480,6 +487,8 @@ def safe_join(directory: str, path: str) -> Optional[str]:
if path != "":
filename = posixpath.normpath(path)
else:
return directory
if (
any(sep in filename for sep in _os_alt_seps)
@ -495,7 +504,7 @@ def get_types(cls_set: List[Type]):
docset = []
types = []
for cls in cls_set:
doc = inspect.getdoc(cls)
doc = inspect.getdoc(cls) or ""
doc_lines = doc.split("\n")
for line in doc_lines:
if "value (" in line:
@ -505,10 +514,10 @@ def get_types(cls_set: List[Type]):
def get_server_url_from_ws_url(ws_url: str):
ws_url = urlparse(ws_url)
scheme = "http" if ws_url.scheme == "ws" else "https"
port = f":{ws_url.port}" if ws_url.port else ""
return f"{scheme}://{ws_url.hostname}{port}{ws_url.path.replace('queue/join', '')}"
ws_url_parsed = urlparse(ws_url)
scheme = "http" if ws_url_parsed.scheme == "ws" else "https"
port = f":{ws_url_parsed.port}" if ws_url_parsed.port else ""
return f"{scheme}://{ws_url_parsed.hostname}{port}{ws_url_parsed.path.replace('queue/join', '')}"
set_documentation_group("routes")
@ -553,7 +562,7 @@ class Request:
Parameters:
request: A fastapi.Request
"""
self.request: fastapi.Request = request
self.request = request
self.kwargs: Dict = kwargs
def dict_to_obj(self, d):
@ -578,7 +587,7 @@ def mount_gradio_app(
app: fastapi.FastAPI,
blocks: gradio.Blocks,
path: str,
gradio_api_url: Optional[str] = None,
gradio_api_url: str | None = None,
) -> fastapi.FastAPI:
"""Mount a gradio.Blocks to an existing FastAPI application.
@ -604,10 +613,10 @@ def mount_gradio_app(
@app.on_event("startup")
async def start_queue():
if gradio_app.blocks.enable_queue:
if gradio_app.get_blocks().enable_queue:
if gradio_api_url:
gradio_app.blocks._queue.set_url(gradio_api_url)
gradio_app.blocks.startup_events()
gradio_app.get_blocks()._queue.set_url(gradio_api_url)
gradio_app.get_blocks().startup_events()
app.mount(path, gradio_app)
return app

View File

@ -3,6 +3,7 @@ import os
import platform
import re
import subprocess
from pathlib import Path
from typing import List
VERSION = "0.1"
@ -26,11 +27,11 @@ class Tunnel:
# Check if the file exist
binary_name = f"frpc_{platform.system().lower()}_{machine.lower()}"
binary_path = os.path.join(os.path.dirname(__file__), binary_name)
binary_path = str(Path(__file__).parent / binary_name)
extension = ".exe" if os.name == "nt" else ""
if not os.path.exists(binary_path):
if not Path(binary_path).exists():
import stat
import requests
@ -91,8 +92,14 @@ class Tunnel:
atexit.register(self.kill)
url = ""
while url == "":
if self.proc.stdout is None:
continue
line = self.proc.stdout.readline()
line = line.decode("utf-8")
if "start proxy success" in line:
url = re.search("start proxy success: (.+)\n", line).group(1)
result = re.search("start proxy success: (.+)\n", line)
if result is None:
raise ValueError("Could not create share URL")
else:
url = result.group(1)
return url

View File

@ -6,4 +6,4 @@ pip_required
pip install --upgrade pip
pip install pyright
cd gradio
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py utils.py templates.py
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py mix.py networking.py pipelines.py processing_utils.py queueing.py reload.py routes.py serializing.py strings.py tunneling.py utils.py templates.py

View File

@ -242,13 +242,11 @@ class TestAuthenticatedRoutes:
response = client.post(
"/login",
data=dict(username="test", password="correct_password"),
follow_redirects=False,
)
assert response.status_code == 302
assert response.status_code == 200
response = client.post(
"/login",
data=dict(username="test", password="incorrect_password"),
follow_redirects=False,
)
assert response.status_code == 400