mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Update ruff to 0.1.13, enable more rules, fix issues (#7061)
* add changeset * Update ruff to version 0.1.13 * Correct ruff target version * Enable more Ruff rules and fix issues * Enable ARG and fix issues * Enable PL lints, fix issues --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
aeb0541423
commit
05d8a3c803
8
.changeset/salty-olives-clean.md
Normal file
8
.changeset/salty-olives-clean.md
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
"@gradio/preview": minor
|
||||
"@gradio/wasm": minor
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
feat:Update ruff to 0.1.13, enable more rules, fix issues
|
@ -580,9 +580,8 @@ class Client:
|
||||
# When loading from json, the fn_indices are read as strings
|
||||
# because json keys can only be strings
|
||||
human_info += self._render_endpoints_info(int(fn_index), endpoint_info)
|
||||
else:
|
||||
if num_unnamed_endpoints > 0:
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(all_endpoints=True)\n"
|
||||
elif num_unnamed_endpoints > 0:
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(all_endpoints=True)\n"
|
||||
|
||||
if print_info:
|
||||
print(human_info)
|
||||
@ -1003,7 +1002,7 @@ class Endpoint:
|
||||
result = utils.synchronize_async(
|
||||
self._sse_fn_v0, data, hash_data, helper
|
||||
)
|
||||
elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
|
||||
elif self.protocol in ("sse_v1", "sse_v2"):
|
||||
event_id = utils.synchronize_async(
|
||||
self.client.send_data, data, hash_data
|
||||
)
|
||||
@ -1220,7 +1219,7 @@ class Endpoint:
|
||||
class EndpointV3Compatibility:
|
||||
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict, *args):
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
@ -1673,26 +1672,25 @@ class Job(Future):
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
elif not self.communicator:
|
||||
return StatusUpdate(
|
||||
code=Status.PROCESSING,
|
||||
rank=0,
|
||||
queue_size=None,
|
||||
success=None,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
else:
|
||||
if not self.communicator:
|
||||
return StatusUpdate(
|
||||
code=Status.PROCESSING,
|
||||
rank=0,
|
||||
queue_size=None,
|
||||
success=None,
|
||||
time=time,
|
||||
eta=None,
|
||||
progress_data=None,
|
||||
)
|
||||
else:
|
||||
with self.communicator.lock:
|
||||
eta = self.communicator.job.latest_status.eta
|
||||
if self.verbose and self.space_id and eta and eta > 30:
|
||||
print(
|
||||
f"Due to heavy traffic on this app, the prediction will take approximately {int(eta)} seconds."
|
||||
f"For faster predictions without waiting in queue, you may duplicate the space using: Client.duplicate({self.space_id})"
|
||||
)
|
||||
return self.communicator.job.latest_status
|
||||
with self.communicator.lock:
|
||||
eta = self.communicator.job.latest_status.eta
|
||||
if self.verbose and self.space_id and eta and eta > 30:
|
||||
print(
|
||||
f"Due to heavy traffic on this app, the prediction will take approximately {int(eta)} seconds."
|
||||
f"For faster predictions without waiting in queue, you may duplicate the space using: Client.duplicate({self.space_id})"
|
||||
)
|
||||
return self.communicator.job.latest_status
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forwards any properties to the Future class."""
|
||||
|
@ -11,7 +11,7 @@ documentation_group = None
|
||||
|
||||
|
||||
def set_documentation_group(m):
|
||||
global documentation_group
|
||||
global documentation_group # noqa: PLW0603
|
||||
documentation_group = m
|
||||
if m not in classes_to_document:
|
||||
classes_to_document[m] = []
|
||||
@ -57,7 +57,6 @@ def document(*fns, inherit=False):
|
||||
functions = list(fns)
|
||||
if hasattr(cls, "EVENTS"):
|
||||
functions += cls.EVENTS
|
||||
global documentation_group
|
||||
if inherit:
|
||||
classes_inherit_documentation[cls] = None
|
||||
classes_to_document[documentation_group].append((cls, functions))
|
||||
@ -177,15 +176,14 @@ def document_cls(cls):
|
||||
tag = line[: line.index(":")].lower()
|
||||
value = line[line.index(":") + 2 :]
|
||||
tags[tag] = value
|
||||
elif mode == "description":
|
||||
description_lines.append(line if line.strip() else "<br>")
|
||||
else:
|
||||
if mode == "description":
|
||||
description_lines.append(line if line.strip() else "<br>")
|
||||
else:
|
||||
if not (line.startswith(" ") or not line.strip()):
|
||||
raise SyntaxError(
|
||||
f"Documentation format for {cls.__name__} has format error in line: {line}"
|
||||
)
|
||||
tags[mode].append(line[4:])
|
||||
if not (line.startswith(" ") or not line.strip()):
|
||||
raise SyntaxError(
|
||||
f"Documentation format for {cls.__name__} has format error in line: {line}"
|
||||
)
|
||||
tags[mode].append(line[4:])
|
||||
if "example" in tags:
|
||||
example = "\n".join(tags["example"])
|
||||
del tags["example"]
|
||||
|
@ -141,19 +141,19 @@ if not DISCORD_TOKEN:
|
||||
## You have not specified a DISCORD_TOKEN, which means you have not created a bot account. Please follow these steps:
|
||||
|
||||
### 1. Go to https://discord.com/developers/applications and click 'New Application'
|
||||
|
||||
|
||||
### 2. Give your bot a name 🤖
|
||||
|
||||

|
||||
|
||||
|
||||
## 3. In Settings > Bot, click the 'Reset Token' button to get a new token. Write it down and keep it safe 🔐
|
||||
|
||||
|
||||

|
||||
|
||||
|
||||
## 4. Optionally make the bot public if you want anyone to be able to add it to their servers
|
||||
|
||||
|
||||
## 5. Scroll down and enable 'Message Content Intent' under 'Priviledged Gateway Intents'
|
||||
|
||||
|
||||

|
||||
|
||||
## 6. Save your changes!
|
||||
@ -164,20 +164,20 @@ else:
|
||||
permissions = Permissions(326417525824)
|
||||
url = oauth_url(bot.user.id, permissions=permissions)
|
||||
welcome_message = f"""
|
||||
## Add this bot to your server by clicking this link:
|
||||
|
||||
## Add this bot to your server by clicking this link:
|
||||
|
||||
{url}
|
||||
|
||||
## How to use it?
|
||||
|
||||
The bot can be triggered via `/<<command-name>>` followed by your text prompt.
|
||||
|
||||
|
||||
This will create a thread with the bot's response to your text prompt.
|
||||
You can reply in the thread (without `/<<command-name>>`) to continue the conversation.
|
||||
In the thread, the bot will only reply to the original author of the command.
|
||||
|
||||
⚠️ Note ⚠️: Please make sure this bot's command does have the same name as another command in your server.
|
||||
|
||||
|
||||
⚠️ Note ⚠️: Bot commands do not work in DMs with the bot as of now.
|
||||
"""
|
||||
|
||||
|
@ -131,7 +131,7 @@ def cancel_from_client_demo():
|
||||
|
||||
@pytest.fixture
|
||||
def sentiment_classification_demo():
|
||||
def classifier(text):
|
||||
def classifier(text): # noqa: ARG001
|
||||
time.sleep(1)
|
||||
return {label: random.random() for label in ["POSITIVE", "NEGATIVE", "NEUTRAL"]}
|
||||
|
||||
@ -234,7 +234,7 @@ def count_generator_demo_exception():
|
||||
@pytest.fixture
|
||||
def file_io_demo():
|
||||
demo = gr.Interface(
|
||||
lambda x: print("foox"),
|
||||
lambda _: print("foox"),
|
||||
[gr.File(file_count="multiple"), "file"],
|
||||
[gr.File(file_count="multiple"), "file"],
|
||||
)
|
||||
@ -385,7 +385,7 @@ def gradio_temp_dir(monkeypatch, tmp_path):
|
||||
|
||||
@pytest.fixture
|
||||
def long_response_with_info():
|
||||
def long_response(x):
|
||||
def long_response(_):
|
||||
gr.Info("Beginning long response")
|
||||
time.sleep(17)
|
||||
gr.Info("Done!")
|
||||
|
@ -1,6 +1,6 @@
|
||||
pytest-asyncio
|
||||
pytest==7.1.2
|
||||
ruff==0.1.7
|
||||
ruff==0.1.13
|
||||
pyright==1.1.327
|
||||
gradio
|
||||
pydub==0.25.1
|
||||
|
@ -1,12 +1,10 @@
|
||||
import json
|
||||
|
||||
import gradio._simple_templates
|
||||
import gradio.components as components
|
||||
import gradio.image_utils
|
||||
import gradio.layouts as layouts
|
||||
import gradio.processing_utils
|
||||
import gradio.templates
|
||||
import gradio.themes as themes
|
||||
from gradio import components, layouts, themes
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.chat_interface import ChatInterface
|
||||
from gradio.cli import deploy
|
||||
|
@ -22,7 +22,7 @@ try:
|
||||
from pyodide.http import pyfetch as pyodide_pyfetch # type: ignore
|
||||
except ImportError:
|
||||
|
||||
async def pyodide_pyfetch(*args, **kwargs):
|
||||
async def pyodide_pyfetch(*_args, **_kwargs):
|
||||
raise NotImplementedError(
|
||||
"pyodide.http.pyfetch is not available in this environment."
|
||||
)
|
||||
|
@ -582,7 +582,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
self.space_id = utils.get_space()
|
||||
self.favicon_path = None
|
||||
self.auth = None
|
||||
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", False))
|
||||
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
|
||||
self.app_id = random.getrandbits(64)
|
||||
self.temp_file_sets = []
|
||||
self.title = title
|
||||
@ -872,11 +872,10 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
|
||||
if isinstance(outputs, set):
|
||||
outputs = sorted(outputs, key=lambda x: x._id)
|
||||
else:
|
||||
if outputs is None:
|
||||
outputs = []
|
||||
elif not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
elif outputs is None:
|
||||
outputs = []
|
||||
elif not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
if fn is not None and not cancels:
|
||||
check_function_inputs_match(fn, inputs, inputs_as_dict)
|
||||
@ -2173,7 +2172,7 @@ Received outputs:
|
||||
analytics.launched_analytics(self, data)
|
||||
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1 and not wasm_utils.IS_WASM:
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", "0")) == 1 and not wasm_utils.IS_WASM:
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
|
||||
|
@ -1,4 +1,4 @@
|
||||
""" This file is the part of 'gradio/cli.py' for printing the environment info
|
||||
""" This file is the part of 'gradio/cli.py' for printing the environment info
|
||||
for the cli command 'gradio environment'
|
||||
"""
|
||||
import platform
|
||||
|
@ -246,14 +246,17 @@ def delete_contents(directory: str | Path) -> None:
|
||||
|
||||
|
||||
def _create_frontend(
|
||||
name: str, component: ComponentFiles, directory: Path, package_name: str
|
||||
name: str, # noqa: ARG001
|
||||
component: ComponentFiles,
|
||||
directory: Path,
|
||||
package_name: str,
|
||||
):
|
||||
frontend = directory / "frontend"
|
||||
frontend.mkdir(exist_ok=True)
|
||||
|
||||
p = Path(inspect.getfile(gradio)).parent
|
||||
|
||||
def ignore(s, names):
|
||||
def ignore(_src, names):
|
||||
ignored = []
|
||||
for n in names:
|
||||
if (
|
||||
|
@ -174,7 +174,7 @@ def get_container_name(arg):
|
||||
return str(arg)
|
||||
|
||||
|
||||
def format_type(_type: list[typing.Any], current=None):
|
||||
def format_type(_type: list[typing.Any]):
|
||||
"""Pretty formats a possibly nested type hint."""
|
||||
|
||||
s = []
|
||||
@ -187,18 +187,18 @@ def format_type(_type: list[typing.Any], current=None):
|
||||
elif isinstance(t, list):
|
||||
if len(t) == 0:
|
||||
continue
|
||||
s.append(f"{format_type(t, _current)}")
|
||||
s.append(f"{format_type(t)}")
|
||||
else:
|
||||
s.append(t)
|
||||
if len(s) == 0:
|
||||
return _current
|
||||
elif _current == "Literal" or _current == "Union":
|
||||
elif _current in ("Literal", "Union"):
|
||||
return "| ".join(s)
|
||||
else:
|
||||
return f"{_current}[{','.join(s)}]"
|
||||
|
||||
|
||||
def get_type_hints(param, module, ignore=None):
|
||||
def get_type_hints(param, module):
|
||||
"""Gets the type hints for a parameter."""
|
||||
|
||||
def extract_args(
|
||||
@ -268,7 +268,7 @@ def get_type_hints(param, module, ignore=None):
|
||||
|
||||
if len(new_args) > 0:
|
||||
arg_names.append(new_args)
|
||||
else:
|
||||
else: # noqa: PLR5501
|
||||
if append:
|
||||
arg_names.append(get_param_name(arg))
|
||||
return arg_names
|
||||
@ -313,11 +313,7 @@ def extract_docstrings(module):
|
||||
for member_name, member in inspect.getmembers(obj):
|
||||
if inspect.ismethod(member) or inspect.isfunction(member):
|
||||
# we are are only interested in these methods
|
||||
if (
|
||||
member_name != "__init__"
|
||||
and member_name != "preprocess"
|
||||
and member_name != "postprocess"
|
||||
):
|
||||
if member_name not in ("__init__", "preprocess", "postprocess"):
|
||||
continue
|
||||
|
||||
docs[name]["members"][member_name] = {}
|
||||
@ -385,7 +381,7 @@ def extract_docstrings(module):
|
||||
] = docstring
|
||||
|
||||
# We just want to normalise the arg name to 'value' for the preprocess and postprocess methods
|
||||
if member_name == "postprocess" or member_name == "preprocess":
|
||||
if member_name in ("postprocess", "preprocess"):
|
||||
docs[name]["members"][member_name][
|
||||
"value"
|
||||
] = find_first_non_return_key(
|
||||
@ -468,7 +464,7 @@ def make_js(
|
||||
}})
|
||||
}}
|
||||
}})
|
||||
|
||||
|
||||
Object.entries(refs).forEach(([key, refs]) => {{
|
||||
if (refs.length > 0) {{
|
||||
const el = document.querySelector(`.${{key}}`);
|
||||
@ -575,8 +571,8 @@ def make_user_fn(
|
||||
|
||||
The impact on the users predict function varies depending on whether the component is used as an input or output for an event (or both).
|
||||
|
||||
- When used as an Input, the component only impacts the input signature of the user function.
|
||||
- When used as an output, the component only impacts the return signature of the user function.
|
||||
- When used as an Input, the component only impacts the input signature of the user function.
|
||||
- When used as an output, the component only impacts the return signature of the user function.
|
||||
|
||||
The code snippet below is accurate in cases where the component is used as both an input and an output.
|
||||
|
||||
@ -634,8 +630,8 @@ def make_user_fn_markdown(
|
||||
|
||||
The impact on the users predict function varies depending on whether the component is used as an input or output for an event (or both).
|
||||
|
||||
- When used as an Input, the component only impacts the input signature of the user function.
|
||||
- When used as an output, the component only impacts the return signature of the user function.
|
||||
- When used as an Input, the component only impacts the input signature of the user function.
|
||||
- When used as an output, the component only impacts the return signature of the user function.
|
||||
|
||||
The code snippet below is accurate in cases where the component is used as both an input and an output.
|
||||
|
||||
@ -835,7 +831,7 @@ def make_space(
|
||||
"""The demo must be launched using `if __name__ == '__main__'`, otherwise the docs page will not function correctly.
|
||||
|
||||
To fix this error, launch the demo inside of an if statement like this:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
demo.launch()
|
||||
|
||||
@ -852,7 +848,7 @@ import os
|
||||
|
||||
source += f"""
|
||||
_docs = {docs}
|
||||
|
||||
|
||||
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
|
||||
|
||||
with gr.Blocks(
|
||||
@ -914,8 +910,8 @@ def make_markdown(
|
||||
{description}
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
|
||||
```bash
|
||||
pip install {name}
|
||||
```
|
||||
|
||||
|
@ -116,7 +116,9 @@ def _build(
|
||||
"--mode",
|
||||
"build",
|
||||
]
|
||||
pipe = subprocess.run(node_cmds, capture_output=True, text=True)
|
||||
pipe = subprocess.run(
|
||||
node_cmds, capture_output=True, text=True, check=False
|
||||
)
|
||||
if pipe.returncode != 0:
|
||||
live.update(":red_square: Build failed!")
|
||||
live.update(pipe.stderr)
|
||||
@ -127,7 +129,7 @@ def _build(
|
||||
|
||||
cmds = [shutil.which("python"), "-m", "build", str(name)]
|
||||
live.update(f":construction_worker: Building... [grey37]({' '.join(cmds)})[/]")
|
||||
pipe = subprocess.run(cmds, capture_output=True, text=True)
|
||||
pipe = subprocess.run(cmds, capture_output=True, text=True, check=False)
|
||||
if pipe.returncode != 0:
|
||||
live.update(":red_square: Build failed!")
|
||||
live.update(pipe.stderr)
|
||||
|
@ -29,7 +29,7 @@ def _install_command(directory: Path, live: LivePanelDisplay, npm_install: str):
|
||||
live.update(
|
||||
f":construction_worker: Installing python... [grey37]({escape(' '.join(cmds))})[/]"
|
||||
)
|
||||
pipe = subprocess.run(cmds, capture_output=True, text=True)
|
||||
pipe = subprocess.run(cmds, capture_output=True, text=True, check=False)
|
||||
|
||||
if pipe.returncode != 0:
|
||||
live.update(":red_square: Python installation [bold][red]failed[/][/]")
|
||||
@ -41,7 +41,9 @@ def _install_command(directory: Path, live: LivePanelDisplay, npm_install: str):
|
||||
f":construction_worker: Installing javascript... [grey37]({npm_install})[/]"
|
||||
)
|
||||
with set_directory(directory / "frontend"):
|
||||
pipe = subprocess.run(npm_install.split(), capture_output=True, text=True)
|
||||
pipe = subprocess.run(
|
||||
npm_install.split(), capture_output=True, text=True, check=False
|
||||
)
|
||||
if pipe.returncode != 0:
|
||||
live.update(":red_square: NPM install [bold][red]failed[/][/]")
|
||||
live.update(pipe.stdout)
|
||||
|
@ -43,7 +43,7 @@ COPY --link --chown=1000 . .
|
||||
|
||||
RUN mkdir -p /tmp/cache/
|
||||
RUN chmod a+rwx -R /tmp/cache/
|
||||
ENV TRANSFORMERS_CACHE=/tmp/cache/
|
||||
ENV TRANSFORMERS_CACHE=/tmp/cache/
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
@ -58,7 +58,7 @@ CMD ["python", "{demo}"]
|
||||
"""
|
||||
|
||||
|
||||
def _ignore(s, names):
|
||||
def _ignore(_src, names):
|
||||
ignored = []
|
||||
for n in names:
|
||||
if "__pycache__" in n or n.startswith("dist") or n.startswith("node_modules"):
|
||||
@ -91,9 +91,6 @@ def _publish(
|
||||
Path,
|
||||
Argument(help=f"Path to the wheel directory. Default is {Path('.') / 'dist'}"),
|
||||
] = Path(".") / "dist",
|
||||
bump_version: Annotated[
|
||||
bool, Option(help="Whether to bump the version number.")
|
||||
] = True,
|
||||
upload_pypi: Annotated[bool, Option(help="Whether to upload to PyPI.")] = True,
|
||||
pypi_username: Annotated[str, Option(help="The username for PyPI.")] = "",
|
||||
pypi_password: Annotated[str, Option(help="The password for PyPI.")] = "",
|
||||
|
@ -289,7 +289,7 @@ class Component(ComponentBase, Block):
|
||||
def read_from_flag(
|
||||
self,
|
||||
payload: Any,
|
||||
flag_dir: str | Path | None = None,
|
||||
flag_dir: str | Path | None = None, # noqa: ARG002
|
||||
):
|
||||
"""
|
||||
Convert the data from the csv or jsonl file into the component state.
|
||||
|
@ -132,16 +132,14 @@ class File(Component):
|
||||
"""
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if self.file_count == "single":
|
||||
if isinstance(payload, ListFiles):
|
||||
return self._process_single_file(payload[0])
|
||||
else:
|
||||
return self._process_single_file(payload)
|
||||
else:
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload] # type: ignore
|
||||
else:
|
||||
return [self._process_single_file(payload)] # type: ignore
|
||||
return self._process_single_file(payload)
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload] # type: ignore
|
||||
return [self._process_single_file(payload)] # type: ignore
|
||||
|
||||
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
"""
|
||||
|
@ -11,8 +11,7 @@ import PIL.Image
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
from PIL import ImageOps
|
||||
|
||||
import gradio.image_utils as image_utils
|
||||
from gradio import utils
|
||||
from gradio import image_utils, utils
|
||||
from gradio.components.base import Component, StreamingInput
|
||||
from gradio.data_classes import FileData
|
||||
from gradio.events import Events
|
||||
|
@ -11,8 +11,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
import gradio.image_utils as image_utils
|
||||
from gradio import utils
|
||||
from gradio import image_utils, utils
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioModel
|
||||
from gradio.events import Events
|
||||
|
@ -93,10 +93,14 @@ class JSON(Component):
|
||||
def example_inputs(self) -> Any:
|
||||
return {"foo": "bar"}
|
||||
|
||||
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
|
||||
def flag(
|
||||
self,
|
||||
payload: Any,
|
||||
flag_dir: str | Path = "", # noqa: ARG002
|
||||
) -> str:
|
||||
return json.dumps(payload)
|
||||
|
||||
def read_from_flag(self, payload: Any, flag_dir: str | Path | None = None):
|
||||
def read_from_flag(self, payload: Any, flag_dir: str | Path | None = None): # noqa: ARG002
|
||||
return json.loads(payload)
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
|
@ -28,7 +28,7 @@ class State(Component):
|
||||
def __init__(
|
||||
self,
|
||||
value: Any = None,
|
||||
render: bool = True,
|
||||
render: bool = True, # noqa: ARG002
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
|
@ -150,13 +150,11 @@ class UploadButton(Component):
|
||||
if self.file_count == "single":
|
||||
if isinstance(payload, ListFiles):
|
||||
return self._process_single_file(payload[0])
|
||||
else:
|
||||
return self._process_single_file(payload)
|
||||
else:
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload] # type: ignore
|
||||
else:
|
||||
return [self._process_single_file(payload)] # type: ignore
|
||||
return self._process_single_file(payload)
|
||||
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload] # type: ignore
|
||||
return [self._process_single_file(payload)] # type: ignore
|
||||
|
||||
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
"""
|
||||
|
@ -54,7 +54,7 @@ else:
|
||||
return super().dict(**kwargs)["root"]
|
||||
|
||||
@classmethod
|
||||
def schema(cls, **kwargs):
|
||||
def schema(cls, **_kwargs):
|
||||
# XXX: kwargs are ignored.
|
||||
return schema_of(cls.__fields__["root"].type_) # type: ignore
|
||||
|
||||
|
@ -129,7 +129,7 @@ class EventListenerMethod:
|
||||
|
||||
|
||||
class EventListener(str):
|
||||
def __new__(cls, event_name, *args, **kwargs):
|
||||
def __new__(cls, event_name, *_args, **_kwargs):
|
||||
return super().__new__(cls, event_name)
|
||||
|
||||
def __init__(
|
||||
|
@ -163,7 +163,7 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
|
||||
sources=["upload"], type="filepath", label="Input", render=False
|
||||
),
|
||||
"outputs": components.Label(label="Class", render=False),
|
||||
"preprocess": lambda i: to_binary,
|
||||
"preprocess": lambda _: to_binary,
|
||||
"postprocess": lambda r: postprocess_label(
|
||||
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
),
|
||||
|
@ -87,8 +87,8 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
def flag(
|
||||
self,
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str | None = None,
|
||||
flag_option: str = "", # noqa: ARG002
|
||||
username: str | None = None, # noqa: ARG002
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = Path(flagging_dir) / "log.csv"
|
||||
|
@ -755,15 +755,15 @@ def special_args(
|
||||
elif type_hint == routes.Request:
|
||||
if inputs is not None:
|
||||
inputs.insert(i, request)
|
||||
elif (
|
||||
type_hint == Optional[oauth.OAuthProfile]
|
||||
or type_hint == oauth.OAuthProfile
|
||||
or type_hint == Optional[oauth.OAuthToken]
|
||||
or type_hint == oauth.OAuthToken
|
||||
elif type_hint in (
|
||||
# Note: "OAuthProfile | None" is equals to Optional[OAuthProfile] in Python
|
||||
# => it is automatically handled as well by the above condition
|
||||
# (adding explicit "OAuthProfile | None" would break in Python3.9)
|
||||
# (same for "OAuthToken")
|
||||
Optional[oauth.OAuthProfile],
|
||||
Optional[oauth.OAuthToken],
|
||||
oauth.OAuthProfile,
|
||||
oauth.OAuthToken,
|
||||
):
|
||||
if inputs is not None:
|
||||
# Retrieve session from gr.Request, if it exists (i.e. if user is logged in)
|
||||
@ -776,10 +776,7 @@ def special_args(
|
||||
)
|
||||
|
||||
# Inject user profile
|
||||
if (
|
||||
type_hint == Optional[oauth.OAuthProfile]
|
||||
or type_hint == oauth.OAuthProfile
|
||||
):
|
||||
if type_hint in (Optional[oauth.OAuthProfile], oauth.OAuthProfile):
|
||||
oauth_profile = (
|
||||
session["oauth_info"]["userinfo"]
|
||||
if "oauth_info" in session
|
||||
@ -794,10 +791,7 @@ def special_args(
|
||||
inputs.insert(i, oauth_profile)
|
||||
|
||||
# Inject user token
|
||||
elif (
|
||||
type_hint == Optional[oauth.OAuthToken]
|
||||
or type_hint == oauth.OAuthToken
|
||||
):
|
||||
elif type_hint in (Optional[oauth.OAuthToken], oauth.OAuthToken):
|
||||
oauth_info = (
|
||||
session["oauth_info"] if "oauth_info" in session else None
|
||||
)
|
||||
|
@ -746,7 +746,7 @@ class Interface(Blocks):
|
||||
[],
|
||||
([input_component_column] if input_component_column else []), # type: ignore
|
||||
js=f"""() => {json.dumps(
|
||||
|
||||
|
||||
[{'variant': None, 'visible': True, '__type__': 'update'}]
|
||||
if self.interface_type
|
||||
in [
|
||||
@ -755,7 +755,7 @@ class Interface(Blocks):
|
||||
InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
|
||||
|
||||
)}
|
||||
""",
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"
|
||||
GRADIO_SHARE_SERVER_ADDRESS = os.getenv("GRADIO_SHARE_SERVER_ADDRESS")
|
||||
|
||||
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", False))
|
||||
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
|
||||
GRADIO_WATCH_DIRS = (
|
||||
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
|
||||
)
|
||||
|
@ -121,7 +121,7 @@ def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
|
||||
|
||||
# Define OAuth routes
|
||||
@app.get("/login/huggingface")
|
||||
async def oauth_login(request: fastapi.Request):
|
||||
async def oauth_login(request: fastapi.Request): # noqa: ARG001
|
||||
"""Fake endpoint that redirects to HF OAuth page."""
|
||||
# Define target (where to redirect after login)
|
||||
redirect_uri = _generate_redirect_uri(request)
|
||||
|
@ -91,7 +91,7 @@ class RangedFileResponse(Response):
|
||||
self.headers["content-length"] = str(content_length)
|
||||
pass
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # noqa: ARG002
|
||||
if self.stat_result is None:
|
||||
try:
|
||||
stat_result = await aio_stat(self.path)
|
||||
|
@ -49,8 +49,7 @@ from starlette.background import BackgroundTask
|
||||
from starlette.responses import RedirectResponse, StreamingResponse
|
||||
|
||||
import gradio
|
||||
import gradio.ranged_response as ranged_response
|
||||
from gradio import route_utils, utils, wasm_utils
|
||||
from gradio import ranged_response, route_utils, utils, wasm_utils
|
||||
from gradio.context import Context
|
||||
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
|
||||
from gradio.exceptions import Error
|
||||
@ -238,7 +237,7 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/app_id")
|
||||
@app.get("/app_id/")
|
||||
def app_id(request: fastapi.Request) -> dict:
|
||||
def app_id(request: fastapi.Request) -> dict: # noqa: ARG001
|
||||
return {"app_id": app.get_blocks().app_id}
|
||||
|
||||
@app.get("/dev/reload", dependencies=[Depends(login_check)])
|
||||
@ -354,8 +353,7 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/info/", dependencies=[Depends(login_check)])
|
||||
@app.get("/info", dependencies=[Depends(login_check)])
|
||||
def api_info(serialize: bool = True):
|
||||
# config = app.get_blocks().get_api_info()
|
||||
def api_info():
|
||||
return app.get_blocks().get_api_info() # type: ignore
|
||||
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@ -489,7 +487,10 @@ class App(FastAPI):
|
||||
dependencies=[Depends(login_check)],
|
||||
)
|
||||
async def stream(
|
||||
session_hash: str, run: int, component_id: int, request: fastapi.Request
|
||||
session_hash: str,
|
||||
run: int,
|
||||
component_id: int,
|
||||
request: fastapi.Request, # noqa: ARG001
|
||||
):
|
||||
stream: list = (
|
||||
app.get_blocks()
|
||||
|
@ -71,7 +71,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
go_btn = gr.Button("Go", variant="primary")
|
||||
clear_btn = gr.Button("Clear", variant="secondary")
|
||||
|
||||
def go(*args):
|
||||
def go(*_args):
|
||||
time.sleep(3)
|
||||
return "https://gradio-static-files.s3.us-west-2.amazonaws.com/header-image.jpg"
|
||||
|
||||
@ -79,7 +79,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
|
||||
def clear():
|
||||
time.sleep(0.2)
|
||||
return None
|
||||
|
||||
clear_btn.click(clear, None, img)
|
||||
|
||||
@ -121,10 +120,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
chatbot = gr.Chatbot([("Hello", "Hi")], label="Chatbot")
|
||||
chat_btn = gr.Button("Add messages")
|
||||
|
||||
def chat(history):
|
||||
time.sleep(2)
|
||||
yield [["How are you?", "I am good."]]
|
||||
|
||||
chat_btn.click(
|
||||
lambda history: history
|
||||
+ [["How are you?", "I am good."]]
|
||||
|
@ -359,7 +359,7 @@ with gr.Blocks( # noqa: SIM117
|
||||
go_btn = gr.Button("Go", variant="primary")
|
||||
clear_btn = gr.Button("Clear", variant="secondary")
|
||||
|
||||
def go(*args):
|
||||
def go(*_args):
|
||||
time.sleep(3)
|
||||
return "https://gradio-static-files.s3.us-west-2.amazonaws.com/header-image.jpg"
|
||||
|
||||
@ -372,7 +372,6 @@ with gr.Blocks( # noqa: SIM117
|
||||
|
||||
def clear():
|
||||
time.sleep(0.2)
|
||||
return None
|
||||
|
||||
clear_btn.click(clear, None, img)
|
||||
|
||||
@ -437,10 +436,6 @@ with gr.Blocks( # noqa: SIM117
|
||||
chatbot = gr.Chatbot([("Hello", "Hi")], label="Chatbot")
|
||||
chat_btn = gr.Button("Add messages")
|
||||
|
||||
def chat(history):
|
||||
time.sleep(2)
|
||||
yield [["How are you?", "I am good."]]
|
||||
|
||||
chat_btn.click(
|
||||
lambda history: history
|
||||
+ [["How are you?", "I am good."]]
|
||||
|
@ -591,7 +591,7 @@ def validate_url(possible_url: str) -> bool:
|
||||
try:
|
||||
head_request = httpx.head(possible_url, headers=headers, follow_redirects=True)
|
||||
# some URLs, such as AWS S3 presigned URLs, return a 405 or a 403 for HEAD requests
|
||||
if head_request.status_code == 405 or head_request.status_code == 403:
|
||||
if head_request.status_code in (403, 405):
|
||||
return httpx.get(possible_url, headers=headers).is_success
|
||||
return head_request.is_success
|
||||
except Exception:
|
||||
@ -891,7 +891,7 @@ class MatplotlibBackendMananger:
|
||||
matplotlib.use(self._original_backend)
|
||||
|
||||
|
||||
def tex2svg(formula, *args):
|
||||
def tex2svg(formula, *_args):
|
||||
with MatplotlibBackendMananger():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
@ -37,8 +37,6 @@ def app_id_context(app_id: str):
|
||||
# for the Wasm worker to get a reference to
|
||||
# the Gradio's FastAPI app instance (`app`).
|
||||
def register_app(_app):
|
||||
global app_map
|
||||
|
||||
app_id = _app_id_context_var.get()
|
||||
|
||||
if app_id in app_map:
|
||||
@ -49,5 +47,4 @@ def register_app(_app):
|
||||
|
||||
|
||||
def get_registered_app(app_id: str):
|
||||
global app_map
|
||||
return app_map[app_id]
|
||||
|
@ -22,7 +22,7 @@ if __name__ == "__main__":
|
||||
custom_component = ("gradio-custom-component" in keywords or
|
||||
"gradio custom component" in keywords)
|
||||
if not custom_component:
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
module_name = pyproject_toml["project"]["name"]
|
||||
module = importlib.import_module(module_name)
|
||||
|
@ -1,7 +1,7 @@
|
||||
|
||||
import gradio as gr
|
||||
from gradio_test import Test
|
||||
|
||||
import gradio as gr
|
||||
|
||||
example = Test().example_inputs()
|
||||
|
||||
|
@ -43,7 +43,7 @@ from gradio.wasm_utils import app_id_context
|
||||
|
||||
# Code modified from IPython (BSD license)
|
||||
# Source: https://github.com/ipython/ipython/blob/master/IPython/utils/syspathcontext.py#L42
|
||||
class modified_sys_path:
|
||||
class modified_sys_path: # noqa: N801
|
||||
"""A context for prepending a directory to sys.path for a second."""
|
||||
|
||||
def __init__(self, script_path: str):
|
||||
|
@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
@ -109,7 +109,7 @@ def get_module_paths(module: types.ModuleType) -> Set[str]:
|
||||
# Handling of "namespace packages" in which the __path__ attribute
|
||||
# is a _NamespacePath object with a _path attribute containing
|
||||
# the various paths of the package.
|
||||
lambda m: [p for p in m.__path__._path],
|
||||
lambda m: list(m.__path__._path),
|
||||
]
|
||||
|
||||
all_paths = set()
|
||||
|
@ -105,24 +105,39 @@ exclude = [
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py37"
|
||||
extend-select = ["B", "C", "I", "N", "SIM", "UP"]
|
||||
extend-select = [
|
||||
"ARG",
|
||||
"B",
|
||||
"C",
|
||||
"E",
|
||||
"F",
|
||||
"I",
|
||||
"N",
|
||||
"PL",
|
||||
"SIM",
|
||||
"UP",
|
||||
"W",
|
||||
]
|
||||
ignore = [
|
||||
"C901", # function is too complex (TODO: un-ignore this)
|
||||
"B023", # function definition in loop (TODO: un-ignore this)
|
||||
"B008", # function call in argument defaults
|
||||
"B017", # pytest.raises considered evil
|
||||
"B028", # explicit stacklevel for warnings
|
||||
"E501", # from scripts/lint_backend.sh
|
||||
"SIM105", # contextlib.suppress (has a performance cost)
|
||||
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
|
||||
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
|
||||
"UP006", # use `list` instead of `List` for type annotations (fails for 3.8)
|
||||
"B008", # function call in argument defaults
|
||||
"B017", # pytest.raises considered evil
|
||||
"B023", # function definition in loop (TODO: un-ignore this)
|
||||
"B028", # explicit stacklevel for warnings
|
||||
"C901", # function is too complex (TODO: un-ignore this)
|
||||
"E501", # from scripts/lint_backend.sh
|
||||
"PLR091", # complexity rules
|
||||
"PLR2004", # magic numbers
|
||||
"PLW2901", # `for` loop variable overwritten by assignment target
|
||||
"SIM105", # contextlib.suppress (has a performance cost)
|
||||
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
|
||||
"UP006", # use `list` instead of `List` for type annotations (fails for 3.8)
|
||||
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
|
||||
]
|
||||
exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"demo/*" = [
|
||||
"ARG",
|
||||
"E402", # Demos may have imports not at the top
|
||||
"E741", # Demos may have ambiguous variable names
|
||||
"F405", # Demos may use star imports
|
||||
@ -135,6 +150,15 @@ exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]
|
||||
"UP006", # Pydantic on Python 3.7 requires old-style type annotations (TODO: drop when Python 3.7 is dropped)
|
||||
]
|
||||
"gradio/cli/commands/files/NoTemplateComponent.py" = ["ALL"]
|
||||
"client/python/gradio_client/serializing.py" = [
|
||||
"ARG", # contains backward compatibility code, so args need to be named as such
|
||||
]
|
||||
"client/python/test/*" = [
|
||||
"ARG",
|
||||
]
|
||||
"test/*" = [
|
||||
"ARG",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
GRADIO_ANALYTICS_ENABLED = "False"
|
||||
|
@ -13,7 +13,7 @@ pydantic[email]
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-cov
|
||||
ruff>=0.1.7
|
||||
ruff>=0.1.13
|
||||
respx
|
||||
scikit-image
|
||||
torch
|
||||
|
@ -161,7 +161,7 @@ requests==2.28.1
|
||||
# transformers
|
||||
respx==0.19.2
|
||||
# via -r requirements.in
|
||||
ruff==0.1.7
|
||||
ruff==0.1.13
|
||||
# via -r requirements.in
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
|
@ -3,7 +3,7 @@ import ipaddress
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from unittest import mock as mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@ -15,7 +15,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestAnalytics:
|
||||
@mock.patch("httpx.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_warn_with_unable_to_parse(self, mock_get, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0)
|
||||
@ -28,7 +28,7 @@ class TestAnalytics:
|
||||
== "unable to parse version details from package URL."
|
||||
)
|
||||
|
||||
@mock.patch("httpx.post")
|
||||
@patch("httpx.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(
|
||||
self, mock_post, monkeypatch
|
||||
):
|
||||
@ -37,14 +37,14 @@ class TestAnalytics:
|
||||
analytics._do_normal_analytics_request("placeholder", {})
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("httpx.post")
|
||||
@patch("httpx.post")
|
||||
def test_error_analytics_successful(self, mock_post, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
analytics.error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch.object(wasm_utils, "IS_WASM", True)
|
||||
@mock.patch("gradio.analytics.pyodide_pyfetch")
|
||||
@patch.object(wasm_utils, "IS_WASM", True)
|
||||
@patch("gradio.analytics.pyodide_pyfetch")
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_analytics_successful_in_wasm_mode(
|
||||
self, pyodide_pyfetch, monkeypatch
|
||||
@ -69,11 +69,11 @@ class TestIPAddress:
|
||||
def test_get_ip(self):
|
||||
Context.ip_address = None
|
||||
ip = analytics.get_local_ip_address()
|
||||
if ip == "No internet connection" or ip == "Analytics disabled":
|
||||
if ip in ("No internet connection", "Analytics disabled"):
|
||||
return
|
||||
ipaddress.ip_address(ip)
|
||||
|
||||
@mock.patch("httpx.get")
|
||||
@patch("httpx.get")
|
||||
def test_get_ip_without_internet(self, mock_get, monkeypatch):
|
||||
mock_get.side_effect = httpx.ConnectError("Connection error")
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
|
@ -6,7 +6,6 @@ import pathlib
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
from concurrent.futures import wait
|
||||
from contextlib import contextmanager
|
||||
@ -142,14 +141,14 @@ class TestBlocksMethods:
|
||||
assert difference >= 0.01
|
||||
assert result
|
||||
|
||||
@mock.patch("gradio.analytics._do_analytics_request")
|
||||
@patch("gradio.analytics._do_analytics_request")
|
||||
def test_initiated_analytics(self, mock_anlaytics, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
with gr.Blocks():
|
||||
pass
|
||||
mock_anlaytics.assert_called_once()
|
||||
|
||||
@mock.patch("gradio.analytics._do_analytics_request")
|
||||
@patch("gradio.analytics._do_analytics_request")
|
||||
def test_launch_analytics_does_not_error_with_invalid_blocks(
|
||||
self, mock_anlaytics, monkeypatch
|
||||
):
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
This suite of tests is designed to ensure compatibility between the current version of Gradio
|
||||
This suite of tests is designed to ensure compatibility between the current version of Gradio
|
||||
with custom components created using the previous version of Gradio.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
@ -187,7 +187,7 @@ with gr.Blocks() as demo:
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
||||
|
||||
"""
|
||||
)
|
||||
assert app.strip() == answer.strip()
|
||||
|
@ -1,9 +1,9 @@
|
||||
import io
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from string import capwords
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@ -39,7 +39,7 @@ class TestInterface:
|
||||
|
||||
def test_close_all(self):
|
||||
interface = Interface(lambda input: None, "textbox", "label")
|
||||
interface.close = mock.MagicMock()
|
||||
interface.close = MagicMock()
|
||||
close_all()
|
||||
interface.close.assert_called()
|
||||
|
||||
@ -90,7 +90,7 @@ class TestInterface:
|
||||
)
|
||||
assert dataset_check
|
||||
|
||||
@mock.patch("time.sleep")
|
||||
@patch("time.sleep")
|
||||
def test_block_thread(self, mock_sleep):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
with captured_output() as (out, _):
|
||||
@ -102,8 +102,8 @@ class TestInterface:
|
||||
"Keyboard interruption in main thread... closing server." in output
|
||||
)
|
||||
|
||||
@mock.patch("gradio.utils.colab_check")
|
||||
@mock.patch("gradio.networking.setup_tunnel")
|
||||
@patch("gradio.utils.colab_check")
|
||||
@patch("gradio.networking.setup_tunnel")
|
||||
def test_launch_colab_share_error(self, mock_setup_tunnel, mock_colab_check):
|
||||
mock_setup_tunnel.side_effect = RuntimeError()
|
||||
mock_colab_check.return_value = True
|
||||
@ -121,7 +121,7 @@ class TestInterface:
|
||||
assert prediction_fn.__name__ in repr[0]
|
||||
assert len(repr[0]) == len(repr[1])
|
||||
|
||||
@mock.patch("webbrowser.open")
|
||||
@patch("webbrowser.open")
|
||||
def test_interface_browser(self, mock_browser):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(inbrowser=True, prevent_thread_lock=True)
|
||||
@ -139,7 +139,7 @@ class TestInterface:
|
||||
assert interface.examples_handler.dataset.get_config()["samples_per_page"] == 2
|
||||
interface.close()
|
||||
|
||||
@mock.patch("IPython.display.display")
|
||||
@patch("IPython.display.display")
|
||||
def test_inline_display(self, mock_display):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(inline=True, prevent_thread_lock=True)
|
||||
|
@ -3,7 +3,7 @@ import functools
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager, closing
|
||||
from unittest import mock as mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import gradio_client as grc
|
||||
import numpy as np
|
||||
@ -818,14 +818,14 @@ def test_api_name_set_for_all_events(connect):
|
||||
|
||||
|
||||
class TestShowAPI:
|
||||
@mock.patch.object(wasm_utils, "IS_WASM", True)
|
||||
@patch.object(wasm_utils, "IS_WASM", True)
|
||||
def test_show_api_false_when_is_wasm_true(self):
|
||||
interface = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
|
||||
assert (
|
||||
interface.show_api is False
|
||||
), "show_api should be False when IS_WASM is True"
|
||||
|
||||
@mock.patch.object(wasm_utils, "IS_WASM", False)
|
||||
@patch.object(wasm_utils, "IS_WASM", False)
|
||||
def test_show_api_true_when_is_wasm_false(self):
|
||||
interface = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
|
||||
assert (
|
||||
|
@ -3,10 +3,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@ -40,22 +39,22 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestUtils:
|
||||
@mock.patch("IPython.get_ipython")
|
||||
@patch("IPython.get_ipython")
|
||||
def test_colab_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert colab_check() is False
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
@patch("IPython.get_ipython")
|
||||
def test_ipython_check_import_fail(self, mock_get_ipython):
|
||||
mock_get_ipython.side_effect = ImportError()
|
||||
assert ipython_check() is False
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
@patch("IPython.get_ipython")
|
||||
def test_ipython_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert ipython_check() is False
|
||||
|
||||
@mock.patch("httpx.get")
|
||||
@patch("httpx.get")
|
||||
def test_readme_to_html_doesnt_crash_on_connection_error(self, mock_get):
|
||||
mock_get.side_effect = httpx.ConnectError("Connection error")
|
||||
readme_to_html("placeholder")
|
||||
@ -67,10 +66,10 @@ class TestUtils:
|
||||
assert not sagemaker_check()
|
||||
|
||||
def test_sagemaker_check_false_if_boto3_not_installed(self):
|
||||
with mock.patch.dict(sys.modules, {"boto3": None}, clear=True):
|
||||
with patch.dict(sys.modules, {"boto3": None}, clear=True):
|
||||
assert not sagemaker_check()
|
||||
|
||||
@mock.patch("boto3.session.Session.client")
|
||||
@patch("boto3.session.Session.client")
|
||||
def test_sagemaker_check_true(self, mock_client):
|
||||
mock_client().get_caller_identity = MagicMock(
|
||||
return_value={
|
||||
@ -83,13 +82,13 @@ class TestUtils:
|
||||
assert not kaggle_check()
|
||||
|
||||
def test_kaggle_check_true_when_run_type_set(self):
|
||||
with mock.patch.dict(
|
||||
with patch.dict(
|
||||
os.environ, {"KAGGLE_KERNEL_RUN_TYPE": "Interactive"}, clear=True
|
||||
):
|
||||
assert kaggle_check()
|
||||
|
||||
def test_kaggle_check_true_when_both_set(self):
|
||||
with mock.patch.dict(
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"KAGGLE_KERNEL_RUN_TYPE": "Interactive", "GFOOTBALL_DATA_DIR": "./"},
|
||||
clear=True,
|
||||
@ -97,7 +96,7 @@ class TestUtils:
|
||||
assert kaggle_check()
|
||||
|
||||
def test_kaggle_check_false_when_neither_set(self):
|
||||
with mock.patch.dict(
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"KAGGLE_KERNEL_RUN_TYPE": "", "GFOOTBALL_DATA_DIR": ""},
|
||||
clear=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user