mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
More typing! (#2906)
* started pathlib * blocks.py * more changes * fixes * typing * formatting * typing * renaming files * changelog * script * changelog * lint * routes * renamed * state * formatting * state * type check script * remove strictness * switched to pyright * switched to pyright * fixed flaky tests * fixed test xray * fixed load test * fixed blocks tests * formatting * fixed components test * uncomment tests * fixed interpretation tests * formatting * last tests hopefully * argh lint * component * fixed based on review * refactor * components.py t yping * components.py * formatting * lint script * merge * merge * lint * pathlib * lint * events too * lint script * fixing tests * lint * examples * serializing * more files * formatting * flagging.py * added to lint script * fixed tab * interface.py * attempt fix * refactoring interface * interface refactor * formatting * fix for live interfaces * lint * mix * mix * serialize fix * formatting * all demos queue * networking * added type check * processing_utils * more typing * formatting * type ignored processing utils * s * tunneling * add interpretation * more typing * queuing * serializing * undo interpretation * routes.py * formatting * component type * addressed review * lint * typing * documentation * fixing pydantic * routes * fixed typing in routes * fix tests
This commit is contained in:
parent
6a6e9175e1
commit
38c64a5b0e
@ -1435,6 +1435,7 @@ class Image(
|
||||
assert isinstance(x, dict)
|
||||
x, mask = x["image"], x["mask"]
|
||||
|
||||
assert isinstance(x, str)
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -1,3 +1,6 @@
|
||||
"""Pydantic data models and other dataclasses. This is the only file that uses Optional[]
|
||||
typing syntax instead of | None syntax to work with pydantic"""
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@ -35,8 +38,8 @@ class Estimation(BaseModel):
|
||||
queue_size: int
|
||||
avg_event_process_time: Optional[float]
|
||||
avg_event_concurrent_process_time: Optional[float]
|
||||
rank_eta: Optional[int] = None
|
||||
queue_eta: int
|
||||
rank_eta: Optional[float] = None
|
||||
queue_eta: float
|
||||
|
||||
|
||||
class ProgressUnit(BaseModel):
|
||||
|
@ -1,5 +1,9 @@
|
||||
"""Contains methods that generate documentation for Gradio functions and classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
classes_to_document = {}
|
||||
documentation_group = None
|
||||
@ -30,7 +34,7 @@ def document(*fns):
|
||||
return inner_doc
|
||||
|
||||
|
||||
def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, Optional[str]]:
|
||||
def document_fn(fn: Callable) -> Tuple[str, List[Dict], Dict, str | None]:
|
||||
"""
|
||||
Generates documentation for any function.
|
||||
Parameters:
|
||||
|
@ -3,16 +3,12 @@ Ways to transform interfaces to produce new interfaces
|
||||
"""
|
||||
import asyncio
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import gradio
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
|
||||
set_documentation_group("mix_interface")
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.components import IOComponent
|
||||
|
||||
|
||||
@document()
|
||||
class Parallel(gradio.Interface):
|
||||
@ -32,7 +28,7 @@ class Parallel(gradio.Interface):
|
||||
Returns:
|
||||
an Interface object comparing the given models
|
||||
"""
|
||||
outputs: List[IOComponent] = []
|
||||
outputs = []
|
||||
|
||||
for interface in interfaces:
|
||||
if not (isinstance(interface, gradio.Interface)):
|
||||
@ -44,7 +40,7 @@ class Parallel(gradio.Interface):
|
||||
|
||||
async def parallel_fn(*args):
|
||||
return_values_with_durations = await asyncio.gather(
|
||||
*[interface.call_function(0, args) for interface in interfaces]
|
||||
*[interface.call_function(0, list(args)) for interface in interfaces]
|
||||
)
|
||||
return_values = [rv["prediction"] for rv in return_values_with_durations]
|
||||
combined_list = []
|
||||
@ -97,7 +93,7 @@ class Series(gradio.Interface):
|
||||
]
|
||||
|
||||
# run all of predictions sequentially
|
||||
data = (await interface.call_function(0, data))["prediction"]
|
||||
data = (await interface.call_function(0, list(data)))["prediction"]
|
||||
if len(interface.output_components) == 1:
|
||||
data = [data]
|
||||
|
||||
@ -110,7 +106,7 @@ class Series(gradio.Interface):
|
||||
)
|
||||
]
|
||||
|
||||
if len(interface.output_components) == 1:
|
||||
if len(interface.output_components) == 1: # type: ignore
|
||||
return data[0]
|
||||
return data
|
||||
|
||||
|
@ -11,7 +11,6 @@ import time
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import fastapi
|
||||
import requests
|
||||
import uvicorn
|
||||
|
||||
@ -69,7 +68,7 @@ def get_first_available_port(initial: int, final: int) -> int:
|
||||
)
|
||||
|
||||
|
||||
def configure_app(app: fastapi.FastAPI, blocks: Blocks) -> fastapi.FastAPI:
|
||||
def configure_app(app: App, blocks: Blocks) -> App:
|
||||
auth = blocks.auth
|
||||
if auth is not None:
|
||||
if not callable(auth):
|
||||
@ -183,3 +182,4 @@ def url_ok(url: str) -> bool:
|
||||
time.sleep(0.500)
|
||||
except (ConnectionError, requests.exceptions.ConnectionError):
|
||||
return False
|
||||
return False
|
||||
|
@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Dict
|
||||
from gradio import components
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
from transformers import pipelines
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
|
||||
@ -20,17 +20,18 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"""
|
||||
try:
|
||||
import transformers
|
||||
from transformers import pipelines
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"transformers not installed. Please try `pip install transformers`"
|
||||
)
|
||||
if not isinstance(pipeline, transformers.Pipeline):
|
||||
if not isinstance(pipeline, pipelines.base.Pipeline):
|
||||
raise ValueError("pipeline must be a transformers.Pipeline")
|
||||
|
||||
# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
|
||||
# version of the transformers library that the user has installed.
|
||||
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.AudioClassificationPipeline
|
||||
pipeline, pipelines.audio_classification.AudioClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
@ -41,7 +42,8 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
|
||||
pipeline, transformers.AutomaticSpeechRecognitionPipeline
|
||||
pipeline,
|
||||
pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Audio(
|
||||
@ -52,7 +54,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r["text"],
|
||||
}
|
||||
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
|
||||
pipeline, transformers.FeatureExtractionPipeline
|
||||
pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -61,7 +63,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r[0],
|
||||
}
|
||||
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
|
||||
pipeline, transformers.FillMaskPipeline
|
||||
pipeline, pipelines.fill_mask.FillMaskPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -70,7 +72,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ImageClassificationPipeline
|
||||
pipeline, pipelines.image_classification.ImageClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Image(type="filepath", label="Input Image"),
|
||||
@ -79,7 +81,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
|
||||
pipeline, transformers.QuestionAnsweringPipeline
|
||||
pipeline, pipelines.question_answering.QuestionAnsweringPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
@ -94,7 +96,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: (r["answer"], r["score"]),
|
||||
}
|
||||
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
|
||||
pipeline, transformers.SummarizationPipeline
|
||||
pipeline, pipelines.text2text_generation.SummarizationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(lines=7, label="Input"),
|
||||
@ -103,7 +105,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r[0]["summary_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.TextClassificationPipeline
|
||||
pipeline, pipelines.text_classification.TextClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -112,7 +114,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.TextGenerationPipeline
|
||||
pipeline, pipelines.text_generation.TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -121,7 +123,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "TranslationPipeline") and isinstance(
|
||||
pipeline, transformers.TranslationPipeline
|
||||
pipeline, pipelines.text2text_generation.TranslationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -130,7 +132,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r[0]["translation_text"],
|
||||
}
|
||||
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.Text2TextGenerationPipeline
|
||||
pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
@ -139,7 +141,7 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ZeroShotClassificationPipeline
|
||||
pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
"inputs": [
|
||||
@ -167,9 +169,9 @@ def load_from_pipeline(pipeline: transformers.Pipeline) -> Dict:
|
||||
if isinstance(
|
||||
pipeline,
|
||||
(
|
||||
transformers.TextClassificationPipeline,
|
||||
transformers.Text2TextGenerationPipeline,
|
||||
transformers.TranslationPipeline,
|
||||
pipelines.text_classification.TextClassificationPipeline,
|
||||
pipelines.text2text_generation.Text2TextGenerationPipeline,
|
||||
pipelines.text2text_generation.TranslationPipeline,
|
||||
),
|
||||
):
|
||||
data = pipeline(*data)
|
||||
|
@ -34,11 +34,14 @@ with warnings.catch_warnings():
|
||||
|
||||
def to_binary(x: str | Dict) -> bytes:
|
||||
"""Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
|
||||
if isinstance(x, dict) and not x.get("data"):
|
||||
x = encode_url_or_file_to_base64(x["name"])
|
||||
elif isinstance(x, dict) and x.get("data"):
|
||||
x = x["data"]
|
||||
return base64.b64decode(x.split(",")[1])
|
||||
if isinstance(x, dict):
|
||||
if x.get("data"):
|
||||
base64str = x["data"]
|
||||
else:
|
||||
base64str = encode_url_or_file_to_base64(x["name"])
|
||||
else:
|
||||
base64str = x
|
||||
return base64.b64decode(base64str.split(",")[1])
|
||||
|
||||
|
||||
#########################
|
||||
@ -46,31 +49,33 @@ def to_binary(x: str | Dict) -> bytes:
|
||||
#########################
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
def decode_base64_to_image(encoding: str) -> Image.Image:
|
||||
content = encoding.split(";")[1]
|
||||
image_encoded = content.split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
def encode_url_or_file_to_base64(path, encryption_key=None):
|
||||
if utils.validate_url(path):
|
||||
return encode_url_to_base64(path, encryption_key=encryption_key)
|
||||
def encode_url_or_file_to_base64(path: str | Path, encryption_key: bytes | None = None):
|
||||
if utils.validate_url(str(path)):
|
||||
return encode_url_to_base64(str(path), encryption_key=encryption_key)
|
||||
else:
|
||||
return encode_file_to_base64(path, encryption_key=encryption_key)
|
||||
return encode_file_to_base64(str(path), encryption_key=encryption_key)
|
||||
|
||||
|
||||
def get_mimetype(filename):
|
||||
def get_mimetype(filename: str) -> str | None:
|
||||
mimetype = mimetypes.guess_type(filename)[0]
|
||||
if mimetype is not None:
|
||||
mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
|
||||
return mimetype
|
||||
|
||||
|
||||
def get_extension(encoding):
|
||||
def get_extension(encoding: str) -> str | None:
|
||||
encoding = encoding.replace("audio/wav", "audio/x-wav")
|
||||
type = mimetypes.guess_type(encoding)[0]
|
||||
if type == "audio/flac": # flac is not supported by mimetypes
|
||||
return "flac"
|
||||
elif type is None:
|
||||
return None
|
||||
extension = mimetypes.guess_extension(type)
|
||||
if extension is not None and extension.startswith("."):
|
||||
extension = extension[1:]
|
||||
@ -176,7 +181,7 @@ def resize_and_crop(img, size, crop_type="center"):
|
||||
resize[0] = img.size[0]
|
||||
if size[1] is None:
|
||||
resize[1] = img.size[1]
|
||||
return ImageOps.fit(img, resize, centering=center)
|
||||
return ImageOps.fit(img, resize, centering=center) # type: ignore
|
||||
|
||||
|
||||
##################
|
||||
@ -188,7 +193,7 @@ def audio_from_file(filename, crop_min=0, crop_max=100):
|
||||
try:
|
||||
audio = AudioSegment.from_file(filename)
|
||||
except FileNotFoundError as e:
|
||||
isfile = os.path.isfile(filename)
|
||||
isfile = Path(filename).is_file()
|
||||
msg = (
|
||||
f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
|
||||
+ " Please install `ffmpeg` in your system to use non-WAV audio file formats"
|
||||
@ -215,7 +220,8 @@ def audio_to_file(sample_rate, data, filename):
|
||||
sample_width=data.dtype.itemsize,
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1]),
|
||||
)
|
||||
audio.export(filename, format="wav").close()
|
||||
file = audio.export(filename, format="wav")
|
||||
file.close() # type: ignore
|
||||
|
||||
|
||||
def convert_to_16_bit_wav(data):
|
||||
@ -266,7 +272,7 @@ def decode_base64_to_file(
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
data, extension = decode_base64_to_binary(encoding)
|
||||
if file_path is not None and prefix is None:
|
||||
filename = os.path.basename(file_path)
|
||||
filename = Path(file_path).name
|
||||
prefix = filename
|
||||
if "." in filename:
|
||||
prefix = filename[0 : filename.index(".")]
|
||||
@ -341,7 +347,7 @@ class TempFileManager:
|
||||
return sha1.hexdigest()
|
||||
|
||||
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
|
||||
file_name = os.path.basename(file_path_or_url)
|
||||
file_name = Path(file_path_or_url).name
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
prefix = file_name[0 : file_name.index(".")]
|
||||
@ -365,13 +371,13 @@ class TempFileManager:
|
||||
"""Returns a temporary file path for a copy of the given file path if it does
|
||||
not already exist. Otherwise returns the path to the existing temp file."""
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
temp_dir, _ = os.path.split(f.name)
|
||||
temp_dir = Path(f.name).parent
|
||||
|
||||
temp_file_path = self.get_temp_file_path(file_path)
|
||||
f.name = os.path.join(temp_dir, temp_file_path)
|
||||
full_temp_file_path = os.path.abspath(f.name)
|
||||
f.name = str(temp_dir / temp_file_path)
|
||||
full_temp_file_path = str(Path(f.name).resolve())
|
||||
|
||||
if not os.path.exists(full_temp_file_path):
|
||||
if not Path(full_temp_file_path).exists():
|
||||
shutil.copy2(file_path, full_temp_file_path)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
@ -381,13 +387,13 @@ class TempFileManager:
|
||||
"""Downloads a file and makes a temporary file path for a copy if does not already
|
||||
exist. Otherwise returns the path to the existing temp file."""
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
temp_dir, _ = os.path.split(f.name)
|
||||
temp_dir = Path(f.name).parent
|
||||
|
||||
temp_file_path = self.get_temp_url_path(url)
|
||||
f.name = os.path.join(temp_dir, temp_file_path)
|
||||
full_temp_file_path = os.path.abspath(f.name)
|
||||
f.name = str(temp_dir / temp_file_path)
|
||||
full_temp_file_path = str(Path(f.name).resolve())
|
||||
|
||||
if not os.path.exists(full_temp_file_path):
|
||||
if not Path(full_temp_file_path).exists():
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(full_temp_file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
@ -399,7 +405,7 @@ class TempFileManager:
|
||||
def create_tmp_copy_of_file(file_path, dir=None):
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
file_name = os.path.basename(file_path)
|
||||
file_name = Path(file_path).name
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
prefix = file_name[0 : file_name.index(".")]
|
||||
@ -602,8 +608,8 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
imin_in = np.iinfo(dtype_in).min
|
||||
imax_in = np.iinfo(dtype_in).max
|
||||
if kind_out in "ui":
|
||||
imin_out = np.iinfo(dtype_out).min
|
||||
imax_out = np.iinfo(dtype_out).max
|
||||
imin_out = np.iinfo(dtype_out).min # type: ignore
|
||||
imax_out = np.iinfo(dtype_out).max # type: ignore
|
||||
|
||||
# any -> binary
|
||||
if kind_out == "b":
|
||||
@ -632,23 +638,23 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
|
||||
if not uniform:
|
||||
if kind_out == "u":
|
||||
image_out = np.multiply(image, imax_out, dtype=computation_type)
|
||||
image_out = np.multiply(image, imax_out, dtype=computation_type) # type: ignore
|
||||
else:
|
||||
image_out = np.multiply(
|
||||
image, (imax_out - imin_out) / 2, dtype=computation_type
|
||||
image, (imax_out - imin_out) / 2, dtype=computation_type # type: ignore
|
||||
)
|
||||
image_out -= 1.0 / 2.0
|
||||
np.rint(image_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
|
||||
elif kind_out == "u":
|
||||
image_out = np.multiply(image, imax_out + 1, dtype=computation_type)
|
||||
np.clip(image_out, 0, imax_out, out=image_out)
|
||||
image_out = np.multiply(image, imax_out + 1, dtype=computation_type) # type: ignore
|
||||
np.clip(image_out, 0, imax_out, out=image_out) # type: ignore
|
||||
else:
|
||||
image_out = np.multiply(
|
||||
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type
|
||||
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type # type: ignore
|
||||
)
|
||||
np.floor(image_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out) # type: ignore
|
||||
return image_out.astype(dtype_out)
|
||||
|
||||
# signed/unsigned int -> float
|
||||
@ -661,13 +667,13 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
if kind_in == "u":
|
||||
# using np.divide or np.multiply doesn't copy the data
|
||||
# until the computation time
|
||||
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type)
|
||||
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type) # type: ignore
|
||||
# DirectX uses this conversion also for signed ints
|
||||
# if imin_in:
|
||||
# np.maximum(image, -1.0, out=image)
|
||||
else:
|
||||
image = np.add(image, 0.5, dtype=computation_type)
|
||||
image *= 2 / (imax_in - imin_in)
|
||||
image *= 2 / (imax_in - imin_in) # type: ignore
|
||||
|
||||
return np.asarray(image, dtype_out)
|
||||
|
||||
@ -693,9 +699,9 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)
|
||||
|
||||
image = image.astype(_dtype_bits("i", itemsize_out * 8))
|
||||
image -= imin_in
|
||||
image -= imin_in # type: ignore
|
||||
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
|
||||
image += imin_out
|
||||
image += imin_out # type: ignore
|
||||
return image.astype(dtype_out)
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ import copy
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional, Tuple
|
||||
from typing import Any, Deque, Dict, List, Tuple
|
||||
|
||||
import fastapi
|
||||
|
||||
@ -31,7 +31,7 @@ class Event:
|
||||
self.progress: Progress | None = None
|
||||
self.progress_pending: bool = False
|
||||
|
||||
async def disconnect(self, code=1000):
|
||||
async def disconnect(self, code: int = 1000):
|
||||
await self.websocket.close(code=code)
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class Queue:
|
||||
live_updates: bool,
|
||||
concurrency_count: int,
|
||||
update_intervals: float,
|
||||
max_size: Optional[int],
|
||||
max_size: int | None,
|
||||
blocks_dependencies: List,
|
||||
):
|
||||
self.event_queue: Deque[Event] = deque()
|
||||
@ -54,7 +54,7 @@ class Queue:
|
||||
self.server_path = None
|
||||
self.duration_history_total = 0
|
||||
self.duration_history_count = 0
|
||||
self.avg_process_time = None
|
||||
self.avg_process_time = 0
|
||||
self.avg_concurrent_process_time = None
|
||||
self.queue_duration = 1
|
||||
self.live_updates = live_updates
|
||||
@ -139,7 +139,7 @@ class Queue:
|
||||
if job is None:
|
||||
continue
|
||||
for event in job:
|
||||
if event.progress_pending:
|
||||
if event.progress_pending and event.progress:
|
||||
event.progress_pending = False
|
||||
client_awake = await self.send_message(
|
||||
event, event.progress.dict()
|
||||
@ -290,11 +290,12 @@ class Queue:
|
||||
"headers": dict(websocket.headers),
|
||||
"query_params": dict(websocket.query_params),
|
||||
"path_params": dict(websocket.path_params),
|
||||
"client": dict(host=websocket.client.host, port=websocket.client.port),
|
||||
"client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore
|
||||
}
|
||||
|
||||
async def call_prediction(self, events: List[Event], batch: bool):
|
||||
data = events[0].data
|
||||
assert data is not None, "No event data"
|
||||
token = events[0].token
|
||||
data.event_id = events[0]._id if not batch else None
|
||||
try:
|
||||
@ -346,6 +347,7 @@ class Queue:
|
||||
},
|
||||
)
|
||||
elif response.json.get("is_generating", False):
|
||||
old_response = response
|
||||
while response.json.get("is_generating", False):
|
||||
# Python 3.7 doesn't have named tasks.
|
||||
# In order to determine if a task was cancelled, we
|
||||
@ -425,7 +427,7 @@ class Queue:
|
||||
await self.clean_event(event)
|
||||
return False
|
||||
|
||||
async def get_message(self, event) -> Optional[PredictBody]:
|
||||
async def get_message(self, event) -> PredictBody | None:
|
||||
try:
|
||||
data = await event.websocket.receive_json()
|
||||
return PredictBody(**data)
|
||||
|
@ -8,6 +8,7 @@ $ gradio app.py my_demo, to use variable names other than "demo"
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import gradio
|
||||
from gradio import networking
|
||||
@ -23,13 +24,13 @@ def run_in_reload_mode():
|
||||
demo_name = args[1]
|
||||
|
||||
original_path = args[0]
|
||||
abs_original_path = os.path.abspath(original_path)
|
||||
path = os.path.normpath(original_path)
|
||||
abs_original_path = Path(original_path).name
|
||||
path = str(Path(original_path).resolve())
|
||||
path = path.replace("/", ".")
|
||||
path = path.replace("\\", ".")
|
||||
filename = os.path.splitext(path)[0]
|
||||
filename = Path(path).stem
|
||||
|
||||
gradio_folder = os.path.dirname(inspect.getfile(gradio))
|
||||
gradio_folder = Path(inspect.getfile(gradio)).parent
|
||||
|
||||
port = networking.get_first_available_port(
|
||||
networking.INITIAL_PORT_VALUE,
|
||||
@ -42,16 +43,17 @@ def run_in_reload_mode():
|
||||
message = "Watching:"
|
||||
|
||||
message_change_count = 0
|
||||
if gradio_folder.strip():
|
||||
if str(gradio_folder).strip():
|
||||
command += f'--reload-dir "{gradio_folder}" '
|
||||
message += f" '{gradio_folder}'"
|
||||
message_change_count += 1
|
||||
|
||||
if os.path.dirname(abs_original_path).strip():
|
||||
command += f'--reload-dir "{os.path.dirname(abs_original_path)}"'
|
||||
abs_parent = Path(abs_original_path).parent
|
||||
if str(abs_parent).strip():
|
||||
command += f'--reload-dir "{abs_parent}"'
|
||||
if message_change_count == 1:
|
||||
message += ","
|
||||
message += f" '{os.path.dirname(abs_original_path)}'"
|
||||
message += f" '{abs_parent}'"
|
||||
|
||||
print(message + "\n")
|
||||
os.system(command)
|
||||
|
155
gradio/routes.py
155
gradio/routes.py
@ -1,10 +1,10 @@
|
||||
"""Implements a FastAPI server to run the gradio interface."""
|
||||
"""Implements a FastAPI server to run the gradio interface. Note that some types in this
|
||||
module use the Optional/Union notation so that they work correctly with pydantic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@ -36,7 +36,7 @@ from starlette.responses import RedirectResponse
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
import gradio
|
||||
from gradio import encryptor, utils
|
||||
from gradio import utils
|
||||
from gradio.data_classes import PredictBody, ResetBody
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.exceptions import Error
|
||||
@ -97,9 +97,9 @@ class App(FastAPI):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.tokens = None
|
||||
self.tokens = {}
|
||||
self.auth = None
|
||||
self.blocks: Optional[gradio.Blocks] = None
|
||||
self.blocks: gradio.Blocks | None = None
|
||||
self.state_holder = {}
|
||||
self.iterators = defaultdict(dict)
|
||||
self.lock = asyncio.Lock()
|
||||
@ -124,6 +124,11 @@ class App(FastAPI):
|
||||
self.favicon_path = blocks.favicon_path
|
||||
self.tokens = {}
|
||||
|
||||
def get_blocks(self) -> gradio.Blocks:
|
||||
if self.blocks is None:
|
||||
raise ValueError("No Blocks has been configured for this app.")
|
||||
return self.blocks
|
||||
|
||||
@staticmethod
|
||||
def create_app(blocks: gradio.Blocks) -> App:
|
||||
app = App(default_response_class=ORJSONResponse)
|
||||
@ -151,7 +156,7 @@ class App(FastAPI):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
async def ws_login_check(websocket: WebSocket) -> str:
|
||||
async def ws_login_check(websocket: WebSocket) -> Optional[str]:
|
||||
token = websocket.cookies.get("access-token")
|
||||
return token # token is returned to allow request in queue
|
||||
|
||||
@ -163,13 +168,15 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/app_id")
|
||||
@app.get("/app_id/")
|
||||
def app_id(request: fastapi.Request) -> int:
|
||||
return {"app_id": app.blocks.app_id}
|
||||
def app_id(request: fastapi.Request) -> dict:
|
||||
return {"app_id": app.get_blocks().app_id}
|
||||
|
||||
@app.post("/login")
|
||||
@app.post("/login/")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if app.auth is None:
|
||||
return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
|
||||
if (
|
||||
not callable(app.auth)
|
||||
and username in app.auth
|
||||
@ -191,24 +198,25 @@ class App(FastAPI):
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
blocks = app.get_blocks()
|
||||
|
||||
if app.auth is None or not (user is None):
|
||||
config = app.blocks.config
|
||||
config = app.get_blocks().config
|
||||
else:
|
||||
config = {
|
||||
"auth_required": True,
|
||||
"auth_message": app.blocks.auth_message,
|
||||
"auth_message": blocks.auth_message,
|
||||
}
|
||||
|
||||
try:
|
||||
template = (
|
||||
"frontend/share.html" if app.blocks.share else "frontend/index.html"
|
||||
"frontend/share.html" if blocks.share else "frontend/index.html"
|
||||
)
|
||||
return templates.TemplateResponse(
|
||||
template, {"request": request, "config": config}
|
||||
)
|
||||
except TemplateNotFound:
|
||||
if app.blocks.share:
|
||||
if blocks.share:
|
||||
raise ValueError(
|
||||
"Did you install Gradio from source files? Share mode only "
|
||||
"works when Gradio is installed through the pip package."
|
||||
@ -222,7 +230,7 @@ class App(FastAPI):
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config():
|
||||
return app.blocks.config
|
||||
return app.get_blocks().config
|
||||
|
||||
@app.get("/static/{path:path}")
|
||||
def static_resource(path: str):
|
||||
@ -240,31 +248,20 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/favicon.ico")
|
||||
async def favicon():
|
||||
if app.blocks.favicon_path is None:
|
||||
blocks = app.get_blocks()
|
||||
if blocks.favicon_path is None:
|
||||
return static_resource("img/logo.svg")
|
||||
else:
|
||||
return FileResponse(app.blocks.favicon_path)
|
||||
return FileResponse(blocks.favicon_path)
|
||||
|
||||
@app.get("/file={path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path: str):
|
||||
blocks = app.get_blocks()
|
||||
if utils.validate_url(path):
|
||||
return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND)
|
||||
if (
|
||||
app.blocks.encrypt
|
||||
and isinstance(app.blocks.examples, str)
|
||||
and path.startswith(app.blocks.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
if Path(app.cwd).resolve() in Path(
|
||||
if Path(app.cwd).resolve() in Path(path).resolve().parents or Path(
|
||||
path
|
||||
).resolve().parents or os.path.abspath(path) in set().union(
|
||||
*app.blocks.temp_file_sets
|
||||
): # Need to use os.path.abspath in the second condition to be consistent with usage in TempFileManager
|
||||
).resolve() in set().union(*blocks.temp_file_sets):
|
||||
return FileResponse(
|
||||
Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
|
||||
)
|
||||
@ -289,14 +286,15 @@ class App(FastAPI):
|
||||
|
||||
async def run_predict(
|
||||
body: PredictBody,
|
||||
request: Request,
|
||||
request: Request | List[Request],
|
||||
fn_index_inferred: int,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
_id: deepcopy(getattr(block, "value", None))
|
||||
for _id, block in app.blocks.blocks.items()
|
||||
for _id, block in app.get_blocks().blocks.items()
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
@ -315,12 +313,12 @@ class App(FastAPI):
|
||||
event_id = getattr(body, "event_id", None)
|
||||
raw_input = body.data
|
||||
fn_index = body.fn_index
|
||||
batch = app.blocks.dependencies[fn_index]["batch"]
|
||||
batch = app.get_blocks().dependencies[fn_index_inferred]["batch"]
|
||||
if not (body.batched) and batch:
|
||||
raw_input = [raw_input]
|
||||
try:
|
||||
output = await app.blocks.process_api(
|
||||
fn_index=fn_index,
|
||||
output = await app.get_blocks().process_api(
|
||||
fn_index=fn_index_inferred,
|
||||
inputs=raw_input,
|
||||
request=request,
|
||||
state=session_state,
|
||||
@ -336,7 +334,7 @@ class App(FastAPI):
|
||||
if isinstance(output, Error):
|
||||
raise output
|
||||
except BaseException as error:
|
||||
show_error = app.blocks.show_error or isinstance(error, Error)
|
||||
show_error = app.get_blocks().show_error or isinstance(error, Error)
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
content={"error": str(error) if show_error else None},
|
||||
@ -358,20 +356,23 @@ class App(FastAPI):
|
||||
request: fastapi.Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
fn_index_inferred = None
|
||||
if body.fn_index is None:
|
||||
for i, fn in enumerate(app.blocks.dependencies):
|
||||
for i, fn in enumerate(app.get_blocks().dependencies):
|
||||
if fn["api_name"] == api_name:
|
||||
body.fn_index = i
|
||||
fn_index_inferred = i
|
||||
break
|
||||
if body.fn_index is None:
|
||||
if fn_index_inferred is None:
|
||||
return JSONResponse(
|
||||
content={
|
||||
"error": f"This app has no endpoint /api/{api_name}/."
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
if not app.blocks.api_open and app.blocks.queue_enabled_for_fn(
|
||||
body.fn_index
|
||||
else:
|
||||
fn_index_inferred = body.fn_index
|
||||
if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
|
||||
fn_index_inferred
|
||||
):
|
||||
if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
|
||||
raise HTTPException(
|
||||
@ -381,29 +382,36 @@ class App(FastAPI):
|
||||
|
||||
# If this fn_index cancels jobs, then the only input we need is the
|
||||
# current session hash
|
||||
if app.blocks.dependencies[body.fn_index]["cancels"]:
|
||||
if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
|
||||
body.data = [body.session_hash]
|
||||
if body.request:
|
||||
if body.batched:
|
||||
request = [Request(**req) for req in body.request]
|
||||
gr_request = [Request(**req) for req in body.request]
|
||||
else:
|
||||
request = Request(**body.request)
|
||||
assert isinstance(body.request, dict)
|
||||
gr_request = Request(**body.request)
|
||||
else:
|
||||
request = Request(request)
|
||||
result = await run_predict(body=body, username=username, request=request)
|
||||
gr_request = Request(request)
|
||||
result = await run_predict(
|
||||
body=body,
|
||||
fn_index_inferred=fn_index_inferred,
|
||||
username=username,
|
||||
request=gr_request,
|
||||
)
|
||||
return result
|
||||
|
||||
@app.websocket("/queue/join")
|
||||
async def join_queue(
|
||||
websocket: WebSocket,
|
||||
token: str = Depends(ws_login_check),
|
||||
token: Optional[str] = Depends(ws_login_check),
|
||||
):
|
||||
blocks = app.get_blocks()
|
||||
if app.auth is not None and token is None:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
if app.blocks._queue.server_path is None:
|
||||
if blocks._queue.server_path is None:
|
||||
app_url = get_server_url_from_ws_url(str(websocket.url))
|
||||
app.blocks._queue.set_url(app_url)
|
||||
blocks._queue.set_url(app_url)
|
||||
await websocket.accept()
|
||||
# In order to cancel jobs, we need the session_hash and fn_index
|
||||
# to create a unique id for each job
|
||||
@ -414,27 +422,26 @@ class App(FastAPI):
|
||||
)
|
||||
# set the token into Event to allow using the same token for call_prediction
|
||||
event.token = token
|
||||
event.session_hash = session_info["session_hash"]
|
||||
|
||||
# Continuous events are not put in the queue so that they do not
|
||||
# occupy the queue's resource as they are expected to run forever
|
||||
if app.blocks.dependencies[event.fn_index].get("every", 0):
|
||||
await cancel_tasks([f"{event.session_hash}_{event.fn_index}"])
|
||||
await app.blocks._queue.reset_iterators(
|
||||
event.session_hash, event.fn_index
|
||||
)
|
||||
if blocks.dependencies[event.fn_index].get("every", 0):
|
||||
await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
|
||||
await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
|
||||
task = run_coro_in_background(
|
||||
app.blocks._queue.process_events, [event], False
|
||||
blocks._queue.process_events, [event], False
|
||||
)
|
||||
set_task_name(task, event.session_hash, event.fn_index, batch=False)
|
||||
else:
|
||||
rank = app.blocks._queue.push(event)
|
||||
rank = blocks._queue.push(event)
|
||||
|
||||
if rank is None:
|
||||
await app.blocks._queue.send_message(event, {"msg": "queue_full"})
|
||||
await blocks._queue.send_message(event, {"msg": "queue_full"})
|
||||
await event.disconnect()
|
||||
return
|
||||
estimation = app.blocks._queue.get_estimation()
|
||||
await app.blocks._queue.send_estimation(event, estimation, rank)
|
||||
estimation = blocks._queue.get_estimation()
|
||||
await blocks._queue.send_estimation(event, estimation, rank)
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
if websocket.application_state == WebSocketState.DISCONNECTED:
|
||||
@ -446,19 +453,19 @@ class App(FastAPI):
|
||||
response_model=Estimation,
|
||||
)
|
||||
async def get_queue_status():
|
||||
return app.blocks._queue.get_estimation()
|
||||
return app.get_blocks()._queue.get_estimation()
|
||||
|
||||
@app.get("/startup-events")
|
||||
async def startup_events():
|
||||
if not app.startup_events_triggered:
|
||||
app.blocks.startup_events()
|
||||
app.get_blocks().startup_events()
|
||||
app.startup_events_triggered = True
|
||||
return True
|
||||
return False
|
||||
|
||||
@app.get("/robots.txt", response_class=PlainTextResponse)
|
||||
def robots_txt():
|
||||
if app.blocks.share:
|
||||
if app.get_blocks().share:
|
||||
return "User-agent: *\nDisallow: /"
|
||||
else:
|
||||
return "User-agent: *\nDisallow: "
|
||||
@ -471,7 +478,7 @@ class App(FastAPI):
|
||||
########
|
||||
|
||||
|
||||
def safe_join(directory: str, path: str) -> Optional[str]:
|
||||
def safe_join(directory: str, path: str) -> str | None:
|
||||
"""Safely path to a base directory to avoid escaping the base directory.
|
||||
Borrowed from: werkzeug.security.safe_join"""
|
||||
_os_alt_seps: List[str] = list(
|
||||
@ -480,6 +487,8 @@ def safe_join(directory: str, path: str) -> Optional[str]:
|
||||
|
||||
if path != "":
|
||||
filename = posixpath.normpath(path)
|
||||
else:
|
||||
return directory
|
||||
|
||||
if (
|
||||
any(sep in filename for sep in _os_alt_seps)
|
||||
@ -495,7 +504,7 @@ def get_types(cls_set: List[Type]):
|
||||
docset = []
|
||||
types = []
|
||||
for cls in cls_set:
|
||||
doc = inspect.getdoc(cls)
|
||||
doc = inspect.getdoc(cls) or ""
|
||||
doc_lines = doc.split("\n")
|
||||
for line in doc_lines:
|
||||
if "value (" in line:
|
||||
@ -505,10 +514,10 @@ def get_types(cls_set: List[Type]):
|
||||
|
||||
|
||||
def get_server_url_from_ws_url(ws_url: str):
|
||||
ws_url = urlparse(ws_url)
|
||||
scheme = "http" if ws_url.scheme == "ws" else "https"
|
||||
port = f":{ws_url.port}" if ws_url.port else ""
|
||||
return f"{scheme}://{ws_url.hostname}{port}{ws_url.path.replace('queue/join', '')}"
|
||||
ws_url_parsed = urlparse(ws_url)
|
||||
scheme = "http" if ws_url_parsed.scheme == "ws" else "https"
|
||||
port = f":{ws_url_parsed.port}" if ws_url_parsed.port else ""
|
||||
return f"{scheme}://{ws_url_parsed.hostname}{port}{ws_url_parsed.path.replace('queue/join', '')}"
|
||||
|
||||
|
||||
set_documentation_group("routes")
|
||||
@ -553,7 +562,7 @@ class Request:
|
||||
Parameters:
|
||||
request: A fastapi.Request
|
||||
"""
|
||||
self.request: fastapi.Request = request
|
||||
self.request = request
|
||||
self.kwargs: Dict = kwargs
|
||||
|
||||
def dict_to_obj(self, d):
|
||||
@ -578,7 +587,7 @@ def mount_gradio_app(
|
||||
app: fastapi.FastAPI,
|
||||
blocks: gradio.Blocks,
|
||||
path: str,
|
||||
gradio_api_url: Optional[str] = None,
|
||||
gradio_api_url: str | None = None,
|
||||
) -> fastapi.FastAPI:
|
||||
"""Mount a gradio.Blocks to an existing FastAPI application.
|
||||
|
||||
@ -604,10 +613,10 @@ def mount_gradio_app(
|
||||
|
||||
@app.on_event("startup")
|
||||
async def start_queue():
|
||||
if gradio_app.blocks.enable_queue:
|
||||
if gradio_app.get_blocks().enable_queue:
|
||||
if gradio_api_url:
|
||||
gradio_app.blocks._queue.set_url(gradio_api_url)
|
||||
gradio_app.blocks.startup_events()
|
||||
gradio_app.get_blocks()._queue.set_url(gradio_api_url)
|
||||
gradio_app.get_blocks().startup_events()
|
||||
|
||||
app.mount(path, gradio_app)
|
||||
return app
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
VERSION = "0.1"
|
||||
@ -26,11 +27,11 @@ class Tunnel:
|
||||
|
||||
# Check if the file exist
|
||||
binary_name = f"frpc_{platform.system().lower()}_{machine.lower()}"
|
||||
binary_path = os.path.join(os.path.dirname(__file__), binary_name)
|
||||
binary_path = str(Path(__file__).parent / binary_name)
|
||||
|
||||
extension = ".exe" if os.name == "nt" else ""
|
||||
|
||||
if not os.path.exists(binary_path):
|
||||
if not Path(binary_path).exists():
|
||||
import stat
|
||||
|
||||
import requests
|
||||
@ -91,8 +92,14 @@ class Tunnel:
|
||||
atexit.register(self.kill)
|
||||
url = ""
|
||||
while url == "":
|
||||
if self.proc.stdout is None:
|
||||
continue
|
||||
line = self.proc.stdout.readline()
|
||||
line = line.decode("utf-8")
|
||||
if "start proxy success" in line:
|
||||
url = re.search("start proxy success: (.+)\n", line).group(1)
|
||||
result = re.search("start proxy success: (.+)\n", line)
|
||||
if result is None:
|
||||
raise ValueError("Could not create share URL")
|
||||
else:
|
||||
url = result.group(1)
|
||||
return url
|
||||
|
@ -6,4 +6,4 @@ pip_required
|
||||
pip install --upgrade pip
|
||||
pip install pyright
|
||||
cd gradio
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py utils.py templates.py
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py mix.py networking.py pipelines.py processing_utils.py queueing.py reload.py routes.py serializing.py strings.py tunneling.py utils.py templates.py
|
||||
|
@ -242,13 +242,11 @@ class TestAuthenticatedRoutes:
|
||||
response = client.post(
|
||||
"/login",
|
||||
data=dict(username="test", password="correct_password"),
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert response.status_code == 302
|
||||
assert response.status_code == 200
|
||||
response = client.post(
|
||||
"/login",
|
||||
data=dict(username="test", password="incorrect_password"),
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user