diff --git a/gradio/components.py b/gradio/components.py index 88a9bbbd14..653ff303f7 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -1435,6 +1435,7 @@ class Image( assert isinstance(x, dict) x, mask = x["image"], x["mask"] + assert isinstance(x, str) im = processing_utils.decode_base64_to_image(x) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/gradio/data_classes.py b/gradio/data_classes.py index 7df772144a..1a80869af6 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -1,3 +1,6 @@ +"""Pydantic data models and other dataclasses. This is the only file that uses Optional[] +typing syntax instead of | None syntax to work with pydantic""" + from enum import Enum, auto from typing import Any, Dict, List, Optional, Union @@ -35,8 +38,8 @@ class Estimation(BaseModel): queue_size: int avg_event_process_time: Optional[float] avg_event_concurrent_process_time: Optional[float] - rank_eta: Optional[int] = None - queue_eta: int + rank_eta: Optional[float] = None + queue_eta: float class ProgressUnit(BaseModel): diff --git a/gradio/documentation.py b/gradio/documentation.py index 8384dcb1f9..ac98d61725 100644 --- a/gradio/documentation.py +++ b/gradio/documentation.py @@ -1,5 +1,9 @@ +"""Contains methods that generate documentation for Gradio functions and classes.""" + +from __future__ import annotations + import inspect -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple classes_to_document = {} documentation_group = None @@ -30,7 +34,7 @@ def document(*fns): return inner_doc -def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, Optional[str]]: +def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, str | None]: """ Generates documentation for any function. Parameters: diff --git a/gradio/mix.py b/gradio/mix.py index 7b43852cbb..aba81eb83a 100644 --- a/gradio/mix.py +++ b/gradio/mix.py @@ -3,16 +3,12 @@ Ways to transform interfaces to produce new interfaces """ import asyncio import warnings -from typing import TYPE_CHECKING, List import gradio from gradio.documentation import document, set_documentation_group set_documentation_group("mix_interface") -if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). - from gradio.components import IOComponent - @document() class Parallel(gradio.Interface): @@ -32,7 +28,7 @@ class Parallel(gradio.Interface): Returns: an Interface object comparing the given models """ - outputs: List[IOComponent] = [] + outputs = [] for interface in interfaces: if not (isinstance(interface, gradio.Interface)): @@ -44,7 +40,7 @@ class Parallel(gradio.Interface): async def parallel_fn(*args): return_values_with_durations = await asyncio.gather( - *[interface.call_function(0, args) for interface in interfaces] + *[interface.call_function(0, list(args)) for interface in interfaces] ) return_values = [rv["prediction"] for rv in return_values_with_durations] combined_list = [] @@ -97,7 +93,7 @@ class Series(gradio.Interface): ] # run all of predictions sequentially - data = (await interface.call_function(0, data))["prediction"] + data = (await interface.call_function(0, list(data)))["prediction"] if len(interface.output_components) == 1: data = [data] @@ -110,7 +106,7 @@ class Series(gradio.Interface): ) ] - if len(interface.output_components) == 1: + if len(interface.output_components) == 1: # type: ignore return data[0] return data diff --git a/gradio/networking.py b/gradio/networking.py index 79480f217a..e2ab8c6fc0 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -11,7 +11,6 @@ import time import warnings from typing import TYPE_CHECKING, Tuple -import fastapi import requests import uvicorn @@ -69,7 +68,7 @@ def get_first_available_port(initial: int, final: int) -> int: ) -def configure_app(app: fastapi.FastAPI, blocks: Blocks) -> fastapi.FastAPI: +def configure_app(app: App, blocks: Blocks) -> App: auth = blocks.auth if auth is not None: if not callable(auth): @@ -183,3 +182,4 @@ def url_ok(url: str) -> bool: time.sleep(0.500) except (ConnectionError, requests.exceptions.ConnectionError): return False + return False diff --git a/gradio/pipelines.py b/gradio/pipelines.py index 1dd840ae69..f974ed6d39 100644 --- a/gradio/pipelines.py +++ b/gradio/pipelines.py @@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Dict from gradio import components if TYPE_CHECKING: # Only import for type checking (is False at runtime). - import transformers + from transformers import pipelines -def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: +def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict: """ Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline. pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface @@ -20,17 +20,18 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: """ try: import transformers + from transformers import pipelines except ImportError: raise ImportError( "transformers not installed. Please try `pip install transformers`" ) - if not isinstance(pipeline, transformers.Pipeline): + if not isinstance(pipeline, pipelines.base.Pipeline): raise ValueError("pipeline must be a transformers.Pipeline") # Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the # version of the transformers library that the user has installed. if hasattr(transformers, "AudioClassificationPipeline") and isinstance( - pipeline, transformers.AudioClassificationPipeline + pipeline, pipelines.audio_classification.AudioClassificationPipeline ): pipeline_info = { "inputs": components.Audio( @@ -41,7 +42,8 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance( - pipeline, transformers.AutomaticSpeechRecognitionPipeline + pipeline, + pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline, ): pipeline_info = { "inputs": components.Audio( @@ -52,7 +54,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r["text"], } elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance( - pipeline, transformers.FeatureExtractionPipeline + pipeline, pipelines.feature_extraction.FeatureExtractionPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -61,7 +63,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r[0], } elif hasattr(transformers, "FillMaskPipeline") and isinstance( - pipeline, transformers.FillMaskPipeline + pipeline, pipelines.fill_mask.FillMaskPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -70,7 +72,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: {i["token_str"]: i["score"] for i in r}, } elif hasattr(transformers, "ImageClassificationPipeline") and isinstance( - pipeline, transformers.ImageClassificationPipeline + pipeline, pipelines.image_classification.ImageClassificationPipeline ): pipeline_info = { "inputs": components.Image(type="filepath", label="Input Image"), @@ -79,7 +81,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance( - pipeline, transformers.QuestionAnsweringPipeline + pipeline, pipelines.question_answering.QuestionAnsweringPipeline ): pipeline_info = { "inputs": [ @@ -94,7 +96,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: (r["answer"], r["score"]), } elif hasattr(transformers, "SummarizationPipeline") and isinstance( - pipeline, transformers.SummarizationPipeline + pipeline, pipelines.text2text_generation.SummarizationPipeline ): pipeline_info = { "inputs": components.Textbox(lines=7, label="Input"), @@ -103,7 +105,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r[0]["summary_text"], } elif hasattr(transformers, "TextClassificationPipeline") and isinstance( - pipeline, transformers.TextClassificationPipeline + pipeline, pipelines.text_classification.TextClassificationPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -112,7 +114,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}, } elif hasattr(transformers, "TextGenerationPipeline") and isinstance( - pipeline, transformers.TextGenerationPipeline + pipeline, pipelines.text_generation.TextGenerationPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -121,7 +123,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r[0]["generated_text"], } elif hasattr(transformers, "TranslationPipeline") and isinstance( - pipeline, transformers.TranslationPipeline + pipeline, pipelines.text2text_generation.TranslationPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -130,7 +132,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r[0]["translation_text"], } elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance( - pipeline, transformers.Text2TextGenerationPipeline + pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline ): pipeline_info = { "inputs": components.Textbox(label="Input"), @@ -139,7 +141,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: "postprocess": lambda r: r[0]["generated_text"], } elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance( - pipeline, transformers.ZeroShotClassificationPipeline + pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline ): pipeline_info = { "inputs": [ @@ -167,9 +169,9 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict: if isinstance( pipeline, ( - transformers.TextClassificationPipeline, - transformers.Text2TextGenerationPipeline, - transformers.TranslationPipeline, + pipelines.text_classification.TextClassificationPipeline, + pipelines.text2text_generation.Text2TextGenerationPipeline, + pipelines.text2text_generation.TranslationPipeline, ), ): data = pipeline(*data) diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index f408b37b3c..98e35365c9 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -34,11 +34,14 @@ with warnings.catch_warnings(): def to_binary(x: str | Dict) -> bytes: """Converts a base64 string or dictionary to a binary string that can be sent in a POST.""" - if isinstance(x, dict) and not x.get("data"): - x = encode_url_or_file_to_base64(x["name"]) - elif isinstance(x, dict) and x.get("data"): - x = x["data"] - return base64.b64decode(x.split(",")[1]) + if isinstance(x, dict): + if x.get("data"): + base64str = x["data"] + else: + base64str = encode_url_or_file_to_base64(x["name"]) + else: + base64str = x + return base64.b64decode(base64str.split(",")[1]) ######################### @@ -46,31 +49,33 @@ def to_binary(x: str | Dict) -> bytes: ######################### -def decode_base64_to_image(encoding): +def decode_base64_to_image(encoding: str) -> Image.Image: content = encoding.split(";")[1] image_encoded = content.split(",")[1] return Image.open(BytesIO(base64.b64decode(image_encoded))) -def encode_url_or_file_to_base64(path, encryption_key=None): - if utils.validate_url(path): - return encode_url_to_base64(path, encryption_key=encryption_key) +def encode_url_or_file_to_base64(path: str | Path, encryption_key: bytes | None = None): + if utils.validate_url(str(path)): + return encode_url_to_base64(str(path), encryption_key=encryption_key) else: - return encode_file_to_base64(path, encryption_key=encryption_key) + return encode_file_to_base64(str(path), encryption_key=encryption_key) -def get_mimetype(filename): +def get_mimetype(filename: str) -> str | None: mimetype = mimetypes.guess_type(filename)[0] if mimetype is not None: mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac") return mimetype -def get_extension(encoding): +def get_extension(encoding: str) -> str | None: encoding = encoding.replace("audio/wav", "audio/x-wav") type = mimetypes.guess_type(encoding)[0] if type == "audio/flac": # flac is not supported by mimetypes return "flac" + elif type is None: + return None extension = mimetypes.guess_extension(type) if extension is not None and extension.startswith("."): extension = extension[1:] @@ -176,7 +181,7 @@ def resize_and_crop(img, size, crop_type="center"): resize[0] = img.size[0] if size[1] is None: resize[1] = img.size[1] - return ImageOps.fit(img, resize, centering=center) + return ImageOps.fit(img, resize, centering=center) # type: ignore ################## @@ -188,7 +193,7 @@ def audio_from_file(filename, crop_min=0, crop_max=100): try: audio = AudioSegment.from_file(filename) except FileNotFoundError as e: - isfile = os.path.isfile(filename) + isfile = Path(filename).is_file() msg = ( f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found." + " Please install `ffmpeg` in your system to use non-WAV audio file formats" @@ -215,7 +220,8 @@ def audio_to_file(sample_rate, data, filename): sample_width=data.dtype.itemsize, channels=(1 if len(data.shape) == 1 else data.shape[1]), ) - audio.export(filename, format="wav").close() + file = audio.export(filename, format="wav") + file.close() # type: ignore def convert_to_16_bit_wav(data): @@ -266,7 +272,7 @@ def decode_base64_to_file( os.makedirs(dir, exist_ok=True) data, extension = decode_base64_to_binary(encoding) if file_path is not None and prefix is None: - filename = os.path.basename(file_path) + filename = Path(file_path).name prefix = filename if "." in filename: prefix = filename[0 : filename.index(".")] @@ -341,7 +347,7 @@ class TempFileManager: return sha1.hexdigest() def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]: - file_name = os.path.basename(file_path_or_url) + file_name = Path(file_path_or_url).name prefix, extension = file_name, None if "." in file_name: prefix = file_name[0 : file_name.index(".")] @@ -365,13 +371,13 @@ class TempFileManager: """Returns a temporary file path for a copy of the given file path if it does not already exist. Otherwise returns the path to the existing temp file.""" f = tempfile.NamedTemporaryFile() - temp_dir, _ = os.path.split(f.name) + temp_dir = Path(f.name).parent temp_file_path = self.get_temp_file_path(file_path) - f.name = os.path.join(temp_dir, temp_file_path) - full_temp_file_path = os.path.abspath(f.name) + f.name = str(temp_dir / temp_file_path) + full_temp_file_path = str(Path(f.name).resolve()) - if not os.path.exists(full_temp_file_path): + if not Path(full_temp_file_path).exists(): shutil.copy2(file_path, full_temp_file_path) self.temp_files.add(full_temp_file_path) @@ -381,13 +387,13 @@ class TempFileManager: """Downloads a file and makes a temporary file path for a copy if does not already exist. Otherwise returns the path to the existing temp file.""" f = tempfile.NamedTemporaryFile() - temp_dir, _ = os.path.split(f.name) + temp_dir = Path(f.name).parent temp_file_path = self.get_temp_url_path(url) - f.name = os.path.join(temp_dir, temp_file_path) - full_temp_file_path = os.path.abspath(f.name) + f.name = str(temp_dir / temp_file_path) + full_temp_file_path = str(Path(f.name).resolve()) - if not os.path.exists(full_temp_file_path): + if not Path(full_temp_file_path).exists(): with requests.get(url, stream=True) as r: with open(full_temp_file_path, "wb") as f: shutil.copyfileobj(r.raw, f) @@ -399,7 +405,7 @@ class TempFileManager: def create_tmp_copy_of_file(file_path, dir=None): if dir is not None: os.makedirs(dir, exist_ok=True) - file_name = os.path.basename(file_path) + file_name = Path(file_path).name prefix, extension = file_name, None if "." in file_name: prefix = file_name[0 : file_name.index(".")] @@ -602,8 +608,8 @@ def _convert(image, dtype, force_copy=False, uniform=False): imin_in = np.iinfo(dtype_in).min imax_in = np.iinfo(dtype_in).max if kind_out in "ui": - imin_out = np.iinfo(dtype_out).min - imax_out = np.iinfo(dtype_out).max + imin_out = np.iinfo(dtype_out).min # type: ignore + imax_out = np.iinfo(dtype_out).max # type: ignore # any -> binary if kind_out == "b": @@ -632,23 +638,23 @@ def _convert(image, dtype, force_copy=False, uniform=False): if not uniform: if kind_out == "u": - image_out = np.multiply(image, imax_out, dtype=computation_type) + image_out = np.multiply(image, imax_out, dtype=computation_type) # type: ignore else: image_out = np.multiply( - image, (imax_out - imin_out) / 2, dtype=computation_type + image, (imax_out - imin_out) / 2, dtype=computation_type # type: ignore ) image_out -= 1.0 / 2.0 np.rint(image_out, out=image_out) - np.clip(image_out, imin_out, imax_out, out=image_out) + np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore elif kind_out == "u": - image_out = np.multiply(image, imax_out + 1, dtype=computation_type) - np.clip(image_out, 0, imax_out, out=image_out) + image_out = np.multiply(image, imax_out + 1, dtype=computation_type) # type: ignore + np.clip(image_out, 0, imax_out, out=image_out) # type: ignore else: image_out = np.multiply( - image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type + image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type # type: ignore ) np.floor(image_out, out=image_out) - np.clip(image_out, imin_out, imax_out, out=image_out) + np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore return image_out.astype(dtype_out) # signed/unsigned int -> float @@ -661,13 +667,13 @@ def _convert(image, dtype, force_copy=False, uniform=False): if kind_in == "u": # using np.divide or np.multiply doesn't copy the data # until the computation time - image = np.multiply(image, 1.0 / imax_in, dtype=computation_type) + image = np.multiply(image, 1.0 / imax_in, dtype=computation_type) # type: ignore # DirectX uses this conversion also for signed ints # if imin_in: # np.maximum(image, -1.0, out=image) else: image = np.add(image, 0.5, dtype=computation_type) - image *= 2 / (imax_in - imin_in) + image *= 2 / (imax_in - imin_in) # type: ignore return np.asarray(image, dtype_out) @@ -693,9 +699,9 @@ def _convert(image, dtype, force_copy=False, uniform=False): return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1) image = image.astype(_dtype_bits("i", itemsize_out * 8)) - image -= imin_in + image -= imin_in # type: ignore image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False) - image += imin_out + image += imin_out # type: ignore return image.astype(dtype_out) diff --git a/gradio/queueing.py b/gradio/queueing.py index fdd8b96343..1922449766 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -5,7 +5,7 @@ import copy import sys import time from collections import deque -from typing import Any, Deque, Dict, List, Optional, Tuple +from typing import Any, Deque, Dict, List, Tuple import fastapi @@ -31,7 +31,7 @@ class Event: self.progress: Progress | None = None self.progress_pending: bool = False - async def disconnect(self, code=1000): + async def disconnect(self, code: int = 1000): await self.websocket.close(code=code) @@ -41,7 +41,7 @@ class Queue: live_updates: bool, concurrency_count: int, update_intervals: float, - max_size: Optional[int], + max_size: int | None, blocks_dependencies: List, ): self.event_queue: Deque[Event] = deque() @@ -54,7 +54,7 @@ class Queue: self.server_path = None self.duration_history_total = 0 self.duration_history_count = 0 - self.avg_process_time = None + self.avg_process_time = 0 self.avg_concurrent_process_time = None self.queue_duration = 1 self.live_updates = live_updates @@ -139,7 +139,7 @@ class Queue: if job is None: continue for event in job: - if event.progress_pending: + if event.progress_pending and event.progress: event.progress_pending = False client_awake = await self.send_message( event, event.progress.dict() @@ -290,11 +290,12 @@ class Queue: "headers": dict(websocket.headers), "query_params": dict(websocket.query_params), "path_params": dict(websocket.path_params), - "client": dict(host=websocket.client.host, port=websocket.client.port), + "client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore } async def call_prediction(self, events: List[Event], batch: bool): data = events[0].data + assert data is not None, "No event data" token = events[0].token data.event_id = events[0]._id if not batch else None try: @@ -346,6 +347,7 @@ class Queue: }, ) elif response.json.get("is_generating", False): + old_response = response while response.json.get("is_generating", False): # Python 3.7 doesn't have named tasks. # In order to determine if a task was cancelled, we @@ -425,7 +427,7 @@ class Queue: await self.clean_event(event) return False - async def get_message(self, event) -> Optional[PredictBody]: + async def get_message(self, event) -> PredictBody | None: try: data = await event.websocket.receive_json() return PredictBody(**data) diff --git a/gradio/reload.py b/gradio/reload.py index 5219174280..c77b497595 100644 --- a/gradio/reload.py +++ b/gradio/reload.py @@ -8,6 +8,7 @@ $ gradio app.py my_demo, to use variable names other than "demo" import inspect import os import sys +from pathlib import Path import gradio from gradio import networking @@ -23,13 +24,13 @@ def run_in_reload_mode(): demo_name = args[1] original_path = args[0] - abs_original_path = os.path.abspath(original_path) - path = os.path.normpath(original_path) + abs_original_path = Path(original_path).name + path = str(Path(original_path).resolve()) path = path.replace("/", ".") path = path.replace("\\", ".") - filename = os.path.splitext(path)[0] + filename = Path(path).stem - gradio_folder = os.path.dirname(inspect.getfile(gradio)) + gradio_folder = Path(inspect.getfile(gradio)).parent port = networking.get_first_available_port( networking.INITIAL_PORT_VALUE, @@ -42,16 +43,17 @@ def run_in_reload_mode(): message = "Watching:" message_change_count = 0 - if gradio_folder.strip(): + if str(gradio_folder).strip(): command += f'--reload-dir "{gradio_folder}" ' message += f" '{gradio_folder}'" message_change_count += 1 - if os.path.dirname(abs_original_path).strip(): - command += f'--reload-dir "{os.path.dirname(abs_original_path)}"' + abs_parent = Path(abs_original_path).parent + if str(abs_parent).strip(): + command += f'--reload-dir "{abs_parent}"' if message_change_count == 1: message += "," - message += f" '{os.path.dirname(abs_original_path)}'" + message += f" '{abs_parent}'" print(message + "\n") os.system(command) diff --git a/gradio/routes.py b/gradio/routes.py index b76e77c0ec..969a431f15 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -1,10 +1,10 @@ -"""Implements a FastAPI server to run the gradio interface.""" +"""Implements a FastAPI server to run the gradio interface. Note that some types in this +module use the Optional/Union notation so that they work correctly with pydantic.""" from __future__ import annotations import asyncio import inspect -import io import json import mimetypes import os @@ -36,7 +36,7 @@ from starlette.responses import RedirectResponse from starlette.websockets import WebSocketState import gradio -from gradio import encryptor, utils +from gradio import utils from gradio.data_classes import PredictBody, ResetBody from gradio.documentation import document, set_documentation_group from gradio.exceptions import Error @@ -97,9 +97,9 @@ class App(FastAPI): """ def __init__(self, **kwargs): - self.tokens = None + self.tokens = {} self.auth = None - self.blocks: Optional[gradio.Blocks] = None + self.blocks: gradio.Blocks | None = None self.state_holder = {} self.iterators = defaultdict(dict) self.lock = asyncio.Lock() @@ -124,6 +124,11 @@ class App(FastAPI): self.favicon_path = blocks.favicon_path self.tokens = {} + def get_blocks(self) -> gradio.Blocks: + if self.blocks is None: + raise ValueError("No Blocks has been configured for this app.") + return self.blocks + @staticmethod def create_app(blocks: gradio.Blocks) -> App: app = App(default_response_class=ORJSONResponse) @@ -151,7 +156,7 @@ class App(FastAPI): status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated" ) - async def ws_login_check(websocket: WebSocket) -> str: + async def ws_login_check(websocket: WebSocket) -> Optional[str]: token = websocket.cookies.get("access-token") return token # token is returned to allow request in queue @@ -163,13 +168,15 @@ class App(FastAPI): @app.get("/app_id") @app.get("/app_id/") - def app_id(request: fastapi.Request) -> int: - return {"app_id": app.blocks.app_id} + def app_id(request: fastapi.Request) -> dict: + return {"app_id": app.get_blocks().app_id} @app.post("/login") @app.post("/login/") def login(form_data: OAuth2PasswordRequestForm = Depends()): username, password = form_data.username, form_data.password + if app.auth is None: + return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) if ( not callable(app.auth) and username in app.auth @@ -191,24 +198,25 @@ class App(FastAPI): @app.get("/", response_class=HTMLResponse) def main(request: fastapi.Request, user: str = Depends(get_current_user)): mimetypes.add_type("application/javascript", ".js") + blocks = app.get_blocks() if app.auth is None or not (user is None): - config = app.blocks.config + config = app.get_blocks().config else: config = { "auth_required": True, - "auth_message": app.blocks.auth_message, + "auth_message": blocks.auth_message, } try: template = ( - "frontend/share.html" if app.blocks.share else "frontend/index.html" + "frontend/share.html" if blocks.share else "frontend/index.html" ) return templates.TemplateResponse( template, {"request": request, "config": config} ) except TemplateNotFound: - if app.blocks.share: + if blocks.share: raise ValueError( "Did you install Gradio from source files? Share mode only " "works when Gradio is installed through the pip package." @@ -222,7 +230,7 @@ class App(FastAPI): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) def get_config(): - return app.blocks.config + return app.get_blocks().config @app.get("/static/{path:path}") def static_resource(path: str): @@ -240,31 +248,20 @@ class App(FastAPI): @app.get("/favicon.ico") async def favicon(): - if app.blocks.favicon_path is None: + blocks = app.get_blocks() + if blocks.favicon_path is None: return static_resource("img/logo.svg") else: - return FileResponse(app.blocks.favicon_path) + return FileResponse(blocks.favicon_path) @app.get("/file={path:path}", dependencies=[Depends(login_check)]) def file(path: str): + blocks = app.get_blocks() if utils.validate_url(path): return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND) - if ( - app.blocks.encrypt - and isinstance(app.blocks.examples, str) - and path.startswith(app.blocks.examples) - ): - with open(safe_join(app.cwd, path), "rb") as encrypted_file: - encrypted_data = encrypted_file.read() - file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data) - return FileResponse( - io.BytesIO(file_data), attachment_filename=os.path.basename(path) - ) - if Path(app.cwd).resolve() in Path( + if Path(app.cwd).resolve() in Path(path).resolve().parents or Path( path - ).resolve().parents or os.path.abspath(path) in set().union( - *app.blocks.temp_file_sets - ): # Need to use os.path.abspath in the second condition to be consistent with usage in TempFileManager + ).resolve() in set().union(*blocks.temp_file_sets): return FileResponse( Path(path).resolve(), headers={"Accept-Ranges": "bytes"} ) @@ -289,14 +286,15 @@ class App(FastAPI): async def run_predict( body: PredictBody, - request: Request, + request: Request | List[Request], + fn_index_inferred: int, username: str = Depends(get_current_user), ): if hasattr(body, "session_hash"): if body.session_hash not in app.state_holder: app.state_holder[body.session_hash] = { _id: deepcopy(getattr(block, "value", None)) - for _id, block in app.blocks.blocks.items() + for _id, block in app.get_blocks().blocks.items() if getattr(block, "stateful", False) } session_state = app.state_holder[body.session_hash] @@ -315,12 +313,12 @@ class App(FastAPI): event_id = getattr(body, "event_id", None) raw_input = body.data fn_index = body.fn_index - batch = app.blocks.dependencies[fn_index]["batch"] + batch = app.get_blocks().dependencies[fn_index_inferred]["batch"] if not (body.batched) and batch: raw_input = [raw_input] try: - output = await app.blocks.process_api( - fn_index=fn_index, + output = await app.get_blocks().process_api( + fn_index=fn_index_inferred, inputs=raw_input, request=request, state=session_state, @@ -336,7 +334,7 @@ class App(FastAPI): if isinstance(output, Error): raise output except BaseException as error: - show_error = app.blocks.show_error or isinstance(error, Error) + show_error = app.get_blocks().show_error or isinstance(error, Error) traceback.print_exc() return JSONResponse( content={"error": str(error) if show_error else None}, @@ -358,20 +356,23 @@ class App(FastAPI): request: fastapi.Request, username: str = Depends(get_current_user), ): + fn_index_inferred = None if body.fn_index is None: - for i, fn in enumerate(app.blocks.dependencies): + for i, fn in enumerate(app.get_blocks().dependencies): if fn["api_name"] == api_name: - body.fn_index = i + fn_index_inferred = i break - if body.fn_index is None: + if fn_index_inferred is None: return JSONResponse( content={ "error": f"This app has no endpoint /api/{api_name}/." }, status_code=500, ) - if not app.blocks.api_open and app.blocks.queue_enabled_for_fn( - body.fn_index + else: + fn_index_inferred = body.fn_index + if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn( + fn_index_inferred ): if f"Bearer {app.queue_token}" != request.headers.get("Authorization"): raise HTTPException( @@ -381,29 +382,36 @@ class App(FastAPI): # If this fn_index cancels jobs, then the only input we need is the # current session hash - if app.blocks.dependencies[body.fn_index]["cancels"]: + if app.get_blocks().dependencies[fn_index_inferred]["cancels"]: body.data = [body.session_hash] if body.request: if body.batched: - request = [Request(**req) for req in body.request] + gr_request = [Request(**req) for req in body.request] else: - request = Request(**body.request) + assert isinstance(body.request, dict) + gr_request = Request(**body.request) else: - request = Request(request) - result = await run_predict(body=body, username=username, request=request) + gr_request = Request(request) + result = await run_predict( + body=body, + fn_index_inferred=fn_index_inferred, + username=username, + request=gr_request, + ) return result @app.websocket("/queue/join") async def join_queue( websocket: WebSocket, - token: str = Depends(ws_login_check), + token: Optional[str] = Depends(ws_login_check), ): + blocks = app.get_blocks() if app.auth is not None and token is None: await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return - if app.blocks._queue.server_path is None: + if blocks._queue.server_path is None: app_url = get_server_url_from_ws_url(str(websocket.url)) - app.blocks._queue.set_url(app_url) + blocks._queue.set_url(app_url) await websocket.accept() # In order to cancel jobs, we need the session_hash and fn_index # to create a unique id for each job @@ -414,27 +422,26 @@ class App(FastAPI): ) # set the token into Event to allow using the same token for call_prediction event.token = token + event.session_hash = session_info["session_hash"] # Continuous events are not put in the queue so that they do not # occupy the queue's resource as they are expected to run forever - if app.blocks.dependencies[event.fn_index].get("every", 0): - await cancel_tasks([f"{event.session_hash}_{event.fn_index}"]) - await app.blocks._queue.reset_iterators( - event.session_hash, event.fn_index - ) + if blocks.dependencies[event.fn_index].get("every", 0): + await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"])) + await blocks._queue.reset_iterators(event.session_hash, event.fn_index) task = run_coro_in_background( - app.blocks._queue.process_events, [event], False + blocks._queue.process_events, [event], False ) set_task_name(task, event.session_hash, event.fn_index, batch=False) else: - rank = app.blocks._queue.push(event) + rank = blocks._queue.push(event) if rank is None: - await app.blocks._queue.send_message(event, {"msg": "queue_full"}) + await blocks._queue.send_message(event, {"msg": "queue_full"}) await event.disconnect() return - estimation = app.blocks._queue.get_estimation() - await app.blocks._queue.send_estimation(event, estimation, rank) + estimation = blocks._queue.get_estimation() + await blocks._queue.send_estimation(event, estimation, rank) while True: await asyncio.sleep(60) if websocket.application_state == WebSocketState.DISCONNECTED: @@ -446,19 +453,19 @@ class App(FastAPI): response_model=Estimation, ) async def get_queue_status(): - return app.blocks._queue.get_estimation() + return app.get_blocks()._queue.get_estimation() @app.get("/startup-events") async def startup_events(): if not app.startup_events_triggered: - app.blocks.startup_events() + app.get_blocks().startup_events() app.startup_events_triggered = True return True return False @app.get("/robots.txt", response_class=PlainTextResponse) def robots_txt(): - if app.blocks.share: + if app.get_blocks().share: return "User-agent: *\nDisallow: /" else: return "User-agent: *\nDisallow: " @@ -471,7 +478,7 @@ class App(FastAPI): ######## -def safe_join(directory: str, path: str) -> Optional[str]: +def safe_join(directory: str, path: str) -> str | None: """Safely path to a base directory to avoid escaping the base directory. Borrowed from: werkzeug.security.safe_join""" _os_alt_seps: List[str] = list( @@ -480,6 +487,8 @@ def safe_join(directory: str, path: str) -> Optional[str]: if path != "": filename = posixpath.normpath(path) + else: + return directory if ( any(sep in filename for sep in _os_alt_seps) @@ -495,7 +504,7 @@ def get_types(cls_set: List[Type]): docset = [] types = [] for cls in cls_set: - doc = inspect.getdoc(cls) + doc = inspect.getdoc(cls) or "" doc_lines = doc.split("\n") for line in doc_lines: if "value (" in line: @@ -505,10 +514,10 @@ def get_types(cls_set: List[Type]): def get_server_url_from_ws_url(ws_url: str): - ws_url = urlparse(ws_url) - scheme = "http" if ws_url.scheme == "ws" else "https" - port = f":{ws_url.port}" if ws_url.port else "" - return f"{scheme}://{ws_url.hostname}{port}{ws_url.path.replace('queue/join', '')}" + ws_url_parsed = urlparse(ws_url) + scheme = "http" if ws_url_parsed.scheme == "ws" else "https" + port = f":{ws_url_parsed.port}" if ws_url_parsed.port else "" + return f"{scheme}://{ws_url_parsed.hostname}{port}{ws_url_parsed.path.replace('queue/join', '')}" set_documentation_group("routes") @@ -553,7 +562,7 @@ class Request: Parameters: request: A fastapi.Request """ - self.request: fastapi.Request = request + self.request = request self.kwargs: Dict = kwargs def dict_to_obj(self, d): @@ -578,7 +587,7 @@ def mount_gradio_app( app: fastapi.FastAPI, blocks: gradio.Blocks, path: str, - gradio_api_url: Optional[str] = None, + gradio_api_url: str | None = None, ) -> fastapi.FastAPI: """Mount a gradio.Blocks to an existing FastAPI application. @@ -604,10 +613,10 @@ def mount_gradio_app( @app.on_event("startup") async def start_queue(): - if gradio_app.blocks.enable_queue: + if gradio_app.get_blocks().enable_queue: if gradio_api_url: - gradio_app.blocks._queue.set_url(gradio_api_url) - gradio_app.blocks.startup_events() + gradio_app.get_blocks()._queue.set_url(gradio_api_url) + gradio_app.get_blocks().startup_events() app.mount(path, gradio_app) return app diff --git a/gradio/tunneling.py b/gradio/tunneling.py index 2b1e7286e3..afb1a01d2c 100644 --- a/gradio/tunneling.py +++ b/gradio/tunneling.py @@ -3,6 +3,7 @@ import os import platform import re import subprocess +from pathlib import Path from typing import List VERSION = "0.1" @@ -26,11 +27,11 @@ class Tunnel: # Check if the file exist binary_name = f"frpc_{platform.system().lower()}_{machine.lower()}" - binary_path = os.path.join(os.path.dirname(__file__), binary_name) + binary_path = str(Path(__file__).parent / binary_name) extension = ".exe" if os.name == "nt" else "" - if not os.path.exists(binary_path): + if not Path(binary_path).exists(): import stat import requests @@ -91,8 +92,14 @@ class Tunnel: atexit.register(self.kill) url = "" while url == "": + if self.proc.stdout is None: + continue line = self.proc.stdout.readline() line = line.decode("utf-8") if "start proxy success" in line: - url = re.search("start proxy success: (.+)\n", line).group(1) + result = re.search("start proxy success: (.+)\n", line) + if result is None: + raise ValueError("Could not create share URL") + else: + url = result.group(1) return url diff --git a/scripts/type_check_backend.sh b/scripts/type_check_backend.sh index 80f78e12d6..badf9287bf 100644 --- a/scripts/type_check_backend.sh +++ b/scripts/type_check_backend.sh @@ -6,4 +6,4 @@ pip_required pip install --upgrade pip pip install pyright cd gradio -pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py utils.py templates.py +pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py mix.py networking.py pipelines.py processing_utils.py queueing.py reload.py routes.py serializing.py strings.py tunneling.py utils.py templates.py diff --git a/test/test_routes.py b/test/test_routes.py index 5122153c16..1e7fcbd82c 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -242,13 +242,11 @@ class TestAuthenticatedRoutes: response = client.post( "/login", data=dict(username="test", password="correct_password"), - follow_redirects=False, ) - assert response.status_code == 302 + assert response.status_code == 200 response = client.post( "/login", data=dict(username="test", password="incorrect_password"), - follow_redirects=False, ) assert response.status_code == 400