2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-04-18 12:50:30 +08:00

Ruff update + strictening ()

* Update ruff to 0.0.263

* Get rid of bare except:s

* Fix two Ruff E731s by moving expand_{color,size} to the relevant classes

* Fix Ruff E731 and some variable shadowing in theme builder

* Fix remaining Ruff E731s

* Get rid of unused Ruff ignores

* Fix ruff B904 issues (raise from)

* Fix Ruff B007: loop-control variable not used

* Fix Ruff B011 (do not assert false)

* Remove unused args and kwargs from Progress.tqdm() (spotted via Ruff B026)

* Fix mutable argument default in CheckboxGroup

* Noqa ABC-related lint warnings for Interpretable

* Add missed assert in test_queueing (ruff B015)

* Enable ruff B

* Enable ruff C and fix issues

* Add changelog

* Add UP03[012] after 

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Aarni Koskela 2023-04-29 00:59:42 +03:00 committed by GitHub
parent 1f9584f9a7
commit 12a97746ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 198 additions and 191 deletions

@ -88,6 +88,7 @@ No changes to highlight.
- CI: Python backend lint is only run once, by [@akx](https://github.com/akx) in [PR 3960](https://github.com/gradio-app/gradio/pull/3960)
- Format invocations and concatenations were replaced by f-strings where possible by [@akx](https://github.com/akx) in [PR 3984](https://github.com/gradio-app/gradio/pull/3984)
- Linting rules were made more strict and issues fixed by [@akx](https://github.com/akx) in [PR 3979](https://github.com/gradio-app/gradio/pull/3979).
## Breaking Changes:

@ -169,10 +169,10 @@ class Client:
"""
try:
original_info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
except RepositoryNotFoundError:
except RepositoryNotFoundError as rnfe:
raise ValueError(
f"Could not find Space: {from_id}. If it is a private Space, please provide an `hf_token`."
)
) from rnfe
if to_id:
if "/" in to_id:
to_id = to_id.split("/")[1]
@ -554,8 +554,10 @@ class Client:
result = re.search(r"window.gradio_config = (.*?);[\s]*</script>", r.text)
try:
config = json.loads(result.group(1)) # type: ignore
except AttributeError:
raise ValueError(f"Could not get Gradio config from: {self.src}")
except AttributeError as ae:
raise ValueError(
f"Could not get Gradio config from: {self.src}"
) from ae
if "allow_flagging" in config:
raise ValueError(
"Gradio 2.x is not supported by this client. Please upgrade your Gradio app to Gradio 3.x or higher."
@ -639,20 +641,22 @@ class Endpoint:
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
except KeyError:
except KeyError as ke:
is_public_space = (
self.client.space_id
and not huggingface_hub.space_info(self.client.space_id).private
)
if "error" in result and "429" in result["error"] and is_public_space:
raise utils.TooManyRequestsError(
f"Too many requests to the API, please try again later. To avoid being rate-limited, please duplicate the Space using Client.duplicate({self.client.space_id}) and pass in your Hugging Face token."
)
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
f"and pass in your Hugging Face token."
) from None
elif "error" in result:
raise ValueError(result["error"])
raise ValueError(result["error"]) from None
raise KeyError(
f"Could not find 'data' key in response. Response received: {result}"
)
) from ke
return tuple(output)
return _predict

@ -463,11 +463,11 @@ def set_space_timeout(
)
try:
huggingface_hub.utils.hf_raise_for_status(r)
except huggingface_hub.utils.HfHubHTTPError:
except huggingface_hub.utils.HfHubHTTPError as err:
raise SpaceDuplicationError(
f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
"to set a timeout manually to reduce billing charges."
)
) from err
########################

@ -1,6 +1,6 @@
black==22.6.0
pytest-asyncio
pytest==7.1.2
ruff==0.0.260
ruff==0.0.263
pyright==1.1.298
gradio

@ -987,12 +987,7 @@ class Blocks(BlockContext):
is_generating = False
if block_fn.inputs_as_dict:
processed_input = [
{
input_component: data
for input_component, data in zip(block_fn.inputs, processed_input)
}
]
processed_input = [dict(zip(block_fn.inputs, processed_input))]
if isinstance(requests, list):
request = requests[0]
@ -1211,10 +1206,11 @@ Received outputs:
if predictions[i] is components._Keywords.FINISHED_ITERATING:
output.append(None)
continue
except (IndexError, KeyError):
except (IndexError, KeyError) as err:
raise ValueError(
f"Number of output components does not match number of values returned from from function {block_fn.name}"
)
"Number of output components does not match number "
f"of values returned from from function {block_fn.name}"
) from err
block = self.blocks[output_id]
if getattr(block, "stateful", False):
if not utils.is_update(predictions[i]):

@ -1846,10 +1846,10 @@ class Image(
resized_and_cropped_image = np.array(x)
try:
from skimage.segmentation import slic
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as err:
raise ValueError(
"Error: running this interpretation for images requires scikit-image, please install it first."
)
) from err
try:
segments_slic = slic(
resized_and_cropped_image,
@ -1880,7 +1880,7 @@ class Image(
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
tokens, masks, leave_one_out_tokens = [], [], []
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
for i, segment_value in enumerate(np.unique(segments_slic)):
for segment_value in np.unique(segments_slic):
mask = segments_slic == segment_value
image_screen = np.copy(resized_and_cropped_image)
image_screen[segments_slic == segment_value] = replace_color
@ -3863,10 +3863,11 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable):
try:
text = y["text"]
entities = y["entities"]
except KeyError:
except KeyError as ke:
raise ValueError(
"Expected a dictionary with keys 'text' and 'entities' for the value of the HighlightedText component."
)
"Expected a dictionary with keys 'text' and 'entities' "
"for the value of the HighlightedText component."
) from ke
if len(entities) == 0:
y = [(text, None)]
else:
@ -4049,7 +4050,7 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
def hex_to_rgb(value):
value = value.lstrip("#")
lv = len(value)
return list(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))
return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)]
for mask, label in y[1]:
mask_array = np.zeros((base_img.shape[0], base_img.shape[1]))
@ -5142,18 +5143,18 @@ class ScatterPlot(Plot):
):
"""Helper for creating the scatter plot."""
interactive = True if interactive is None else interactive
encodings = dict(
x=alt.X(
encodings = {
"x": alt.X(
x, # type: ignore
title=x_title or x, # type: ignore
scale=AltairPlot.create_scale(x_lim), # type: ignore
), # ignore: type
y=alt.Y(
"y": alt.Y(
y, # type: ignore
title=y_title or y, # type: ignore
scale=AltairPlot.create_scale(y_lim), # type: ignore
),
)
}
properties = {}
if title:
properties["title"] = title
@ -5473,18 +5474,18 @@ class LinePlot(Plot):
):
"""Helper for creating the scatter plot."""
interactive = True if interactive is None else interactive
encodings = dict(
x=alt.X(
encodings = {
"x": alt.X(
x, # type: ignore
title=x_title or x, # type: ignore
scale=AltairPlot.create_scale(x_lim), # type: ignore
),
y=alt.Y(
"y": alt.Y(
y, # type: ignore
title=y_title or y, # type: ignore
scale=AltairPlot.create_scale(y_lim), # type: ignore
),
)
}
properties = {}
if title:
properties["title"] = title
@ -5796,10 +5797,10 @@ class BarPlot(Plot):
y_lim: List[int] | None = None,
interactive: bool | None = True,
):
"""Helper for creating the scatter plot."""
"""Helper for creating the bar plot."""
interactive = True if interactive is None else interactive
orientation = (
dict(field=group, title=group_title if group_title is not None else group)
{"field": group, "title": group_title if group_title is not None else group}
if group
else {}
)
@ -6124,7 +6125,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
# Narrow type to IOComponent
assert all(
[isinstance(c, IOComponent) for c in self.components]
isinstance(c, IOComponent) for c in self.components
), "All components in a `Dataset` must be subclasses of `IOComponent`"
self.components = [c for c in self.components if isinstance(c, IOComponent)]
for component in self.components:
@ -6138,7 +6139,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
self.label = label
if headers is not None:
self.headers = headers
elif all([c.label is None for c in self.components]):
elif all(c.label is None for c in self.components):
self.headers = []
else:
self.headers = [c.label or "" for c in self.components]

@ -434,8 +434,8 @@ def from_spaces(
) # some basic regex to extract the config
try:
config = json.loads(result.group(1)) # type: ignore
except AttributeError:
raise ValueError(f"Could not load the Space: {space_name}")
except AttributeError as ae:
raise ValueError(f"Could not load the Space: {space_name}") from ae
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
return from_spaces_interface(
space_name, config, alias, api_key, iframe_url, **kwargs
@ -477,14 +477,14 @@ def from_spaces_interface(
data = json.dumps({"data": data})
response = requests.post(api_url, headers=headers, data=data)
result = json.loads(response.content.decode("utf-8"))
if "error" in result and "429" in result["error"]:
raise TooManyRequestsError("Too many requests to the Hugging Face API")
try:
output = result["data"]
except KeyError:
if "error" in result and "429" in result["error"]:
raise TooManyRequestsError("Too many requests to the Hugging Face API")
except KeyError as ke:
raise KeyError(
f"Could not find 'data' key in response from external Space. Response received: {result}"
)
) from ke
if (
len(config["outputs"]) == 1
): # if the fn is supposed to return a single value, pop it

@ -99,12 +99,13 @@ def encode_to_base64(r: requests.Response) -> str:
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
content_type = r.json()[0]["content-type"]
base64_repr = r.json()[0]["blob"]
except KeyError:
data = r.json()[0]
content_type = data["content-type"]
base64_repr = data["blob"]
except KeyError as ke:
raise ValueError(
"Cannot determine content type returned" "by external API."
)
"Cannot determine content type returned by external API."
) from ke
# Case 3: the data prefix is included in the response headers
else:
pass

@ -276,11 +276,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"""
try:
import huggingface_hub
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as err:
raise ImportError(
"Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
)
) from err
hh_version = pkg_resources.get_distribution("huggingface_hub").version
try:
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
@ -416,11 +416,11 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
"""
try:
import huggingface_hub
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as err:
raise ImportError(
"Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
)
) from err
hh_version = pkg_resources.get_distribution("huggingface_hub").version
try:
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
@ -510,9 +510,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
csv_data.append(flag_option)
# Creates metadata dict from row data and dumps it
metadata_dict = {
header: _csv_data for header, _csv_data in zip(headers, csv_data)
}
metadata_dict = dict(zip(headers, csv_data))
self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
if is_new:

@ -422,7 +422,7 @@ class Progress(Iterable):
return next(current_iterable.iterable) # type: ignore
except StopIteration:
self.iterables.pop()
raise StopIteration
raise
else:
return self
@ -463,8 +463,6 @@ class Progress(Iterable):
total: int | None = None,
unit: str = "steps",
_tqdm=None,
*args,
**kwargs,
):
"""
Attaches progress tracker to iterable, like tqdm.
@ -539,7 +537,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
)
if self._progress is not None:
self._progress.event_id = event_id
self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
self._progress.tqdm(iterable, desc, _tqdm=self)
kwargs["file"] = open(os.devnull, "w")
self.__init__orig__(iterable, desc, *args, **kwargs)
@ -611,7 +609,7 @@ def special_args(
"""
signature = inspect.signature(fn)
positional_args = []
for i, param in enumerate(signature.parameters.values()):
for param in signature.parameters.values():
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
break
positional_args.append(param)

@ -133,7 +133,7 @@ class CheckboxGroup(components.CheckboxGroup):
def __init__(
self,
choices: List[str],
default: List[str] = [],
default: List[str] | None = None,
type: str = "value",
label: Optional[str] = None,
optional: bool = False,
@ -146,6 +146,8 @@ class CheckboxGroup(components.CheckboxGroup):
label (str): component name in interface.
optional (bool): this parameter is ignored.
"""
if default is None:
default = []
warnings.warn(
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
)

@ -349,9 +349,9 @@ class Interface(Blocks):
raise ValueError(
"flagging_options must be a list of strings or list of (string, string) tuples."
)
elif all([isinstance(x, str) for x in flagging_options]):
elif all(isinstance(x, str) for x in flagging_options):
self.flagging_options = [(f"Flag as {x}", x) for x in flagging_options]
elif all([isinstance(x, tuple) for x in flagging_options]):
elif all(isinstance(x, tuple) for x in flagging_options):
self.flagging_options = flagging_options
else:
raise ValueError(
@ -620,7 +620,7 @@ class Interface(Blocks):
for output in self.fn(*args):
if len(self.output_components) == 1 and not self.batch:
output = [output]
output = [o for o in output]
output = list(output)
yield output + [
Button.update(visible=False),
Button.update(visible=True),

@ -16,11 +16,11 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio import Interface
class Interpretable(ABC):
class Interpretable(ABC): # noqa: B024
def __init__(self) -> None:
self.set_interpret_parameters()
def set_interpret_parameters(self):
def set_interpret_parameters(self): # noqa: B027
"""
Set any parameters for interpretation. Properties can be set here to be
used in get_interpretation_neighbors and get_interpretation_scores.
@ -189,10 +189,10 @@ async def run_interpret(interface: Interface, raw_input: List):
elif interp == "shap" or interp == "shapley":
try:
import shap # type: ignore
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as err:
raise ValueError(
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
)
) from err
input_component = interface.input_components[i]
if not isinstance(input_component, TokenInterpretable):
raise ValueError(

@ -117,10 +117,11 @@ def start_server(
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((LOCALHOST_NAME, server_port))
s.close()
except OSError:
except OSError as err:
raise OSError(
f"Port {server_port} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all()."
)
f"Port {server_port} is in use. If a gradio.Blocks is running on the port, "
f"you can close() it or gradio.close_all()."
) from err
port = server_port
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
@ -173,9 +174,8 @@ def setup_tunnel(local_host: str, local_port: int, share_token: str) -> str:
address = tunnel.start_tunnel()
return address
except Exception as e:
raise RuntimeError(str(e))
else:
raise RuntimeError("Could not get share link from Gradio API Server.")
raise RuntimeError(str(e)) from e
raise RuntimeError("Could not get share link from Gradio API Server.")
def url_ok(url: str) -> bool:

@ -21,10 +21,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
try:
import transformers
from transformers import pipelines
except ImportError:
except ImportError as ie:
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
)
) from ie
if not isinstance(pipeline, pipelines.base.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")

@ -309,7 +309,7 @@ class Queue:
"headers": dict(websocket.headers),
"query_params": dict(websocket.query_params),
"path_params": dict(websocket.path_params),
"client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore
"client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore
}
async def call_prediction(self, events: List[Event], batch: bool):
@ -444,7 +444,7 @@ class Queue:
event.websocket.send_json(data=data), timeout=timeout
)
return True
except:
except Exception:
await self.clean_event(event)
return False

@ -93,8 +93,10 @@ class RangedFileResponse(Response):
try:
stat_result = await aio_stat(self.path)
self.stat_result = stat_result
except FileNotFoundError:
raise RuntimeError(f"File at path {self.path} does not exist.")
except FileNotFoundError as fnfe:
raise RuntimeError(
f"File at path {self.path} does not exist."
) from fnfe
else:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):

@ -238,17 +238,17 @@ class App(FastAPI):
template,
{"request": request, "config": config},
)
except TemplateNotFound:
except TemplateNotFound as err:
if blocks.share:
raise ValueError(
"Did you install Gradio from source files? Share mode only "
"works when Gradio is installed through the pip package."
)
) from err
else:
raise ValueError(
"Did you install Gradio from source files? You need to build "
"the frontend by running /scripts/build_frontend.sh"
)
) from err
@app.get("/info/", dependencies=[Depends(login_check)])
@app.get("/info", dependencies=[Depends(login_check)])
@ -378,7 +378,7 @@ class App(FastAPI):
# the job being cancelled will not overwrite the state of the iterator.
# In all cases, should_reset will be the empty set the next time
# the fn_index is run.
app.iterators[body.session_hash]["should_reset"] = set([])
app.iterators[body.session_hash]["should_reset"] = set()
else:
session_state = {}
iterators = {}
@ -520,7 +520,7 @@ class App(FastAPI):
# Continuous events are not put in the queue so that they do not
# occupy the queue's resource as they are expected to run forever
if blocks.dependencies[event.fn_index].get("every", 0):
await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
await cancel_tasks({f"{event.session_hash}_{event.fn_index}"})
await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
task = run_coro_in_background(
blocks._queue.process_events, [event], False
@ -593,9 +593,9 @@ class App(FastAPI):
def safe_join(directory: str, path: str) -> str:
"""Safely path to a base directory to avoid escaping the base directory.
Borrowed from: werkzeug.security.safe_join"""
_os_alt_seps: List[str] = list(
_os_alt_seps: List[str] = [
sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/"
)
]
if path == "":
raise HTTPException(400)
@ -735,8 +735,10 @@ class Request:
else:
try:
obj = self.kwargs[name]
except KeyError:
raise AttributeError(f"'Request' object has no attribute '{name}'")
except KeyError as ke:
raise AttributeError(
f"'Request' object has no attribute '{name}'"
) from ke
return self.dict_to_obj(obj)

@ -495,32 +495,6 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
def load_theme(theme_name):
theme = [theme for theme in themes if theme.__name__ == theme_name][0]
expand_color = lambda color: list(
[
color.c50,
color.c100,
color.c200,
color.c300,
color.c400,
color.c500,
color.c600,
color.c700,
color.c800,
color.c900,
color.c950,
]
)
expand_size = lambda size: list(
[
size.xxs,
size.xs,
size.sm,
size.md,
size.lg,
size.xl,
size.xxl,
]
)
parameters = inspect.signature(theme.__init__).parameters
primary_hue = parameters["primary_hue"].default
secondary_hue = parameters["secondary_hue"].default
@ -537,14 +511,9 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
font_mono_is_google = [
isinstance(f, gr.themes.GoogleFont) for f in font_mono
]
font = [f.name for f in font]
font_mono = [f.name for f in font_mono]
pad_to_4 = lambda x: x + [None] * (4 - len(x))
font, font_is_google = pad_to_4(font), pad_to_4(font_is_google)
font_mono, font_mono_is_google = pad_to_4(font_mono), pad_to_4(
font_mono_is_google
)
def pad_to_4(x):
return x + [None] * (4 - len(x))
var_output = []
for variable in flat_variables:
@ -555,17 +524,17 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
return (
[primary_hue.name, secondary_hue.name, neutral_hue.name]
+ expand_color(primary_hue)
+ expand_color(secondary_hue)
+ expand_color(neutral_hue)
+ primary_hue.expand()
+ secondary_hue.expand()
+ neutral_hue.expand()
+ [text_size.name, spacing_size.name, radius_size.name]
+ expand_size(text_size)
+ expand_size(spacing_size)
+ expand_size(radius_size)
+ font
+ font_is_google
+ font_mono
+ font_mono_is_google
+ text_size.expand()
+ spacing_size.expand()
+ radius_size.expand()
+ pad_to_4([f.name for f in font])
+ pad_to_4(font_is_google)
+ pad_to_4(font_mono)
+ pad_to_4(font_mono_is_google)
+ var_output
)
@ -831,9 +800,7 @@ with gr.Blocks(theme=theme) as demo:
font_mono=final_mono_fonts,
)
theme.set(
**{attr: val for attr, val in zip(flat_variables, remaining_args)}
)
theme.set(**dict(zip(flat_variables, remaining_args)))
new_step = (base_theme, args)
if len(history) == 0 or str(history[-1]) != str(new_step):
history.append(new_step)

@ -33,6 +33,21 @@ class Color:
self.name = name
Color.all.append(self)
def expand(self) -> list[str]:
return [
self.c50,
self.c100,
self.c200,
self.c300,
self.c400,
self.c500,
self.c600,
self.c700,
self.c800,
self.c900,
self.c950,
]
slate = Color(
name="slate",

@ -1,3 +1,6 @@
from __future__ import annotations
class Size:
all = []
@ -14,6 +17,9 @@ class Size:
self.name = name
Size.all.append(self)
def expand(self) -> list[str]:
return [self.xxs, self.xs, self.sm, self.md, self.lg, self.xl, self.xxl]
radius_none = Size(
name="radius_none",

@ -82,7 +82,7 @@ def version_check():
warnings.warn("unable to parse version details from package URL.")
except KeyError:
warnings.warn("package URL does not contain version info.")
except:
except Exception:
pass
@ -263,7 +263,7 @@ def sagemaker_check() -> bool:
client = boto3.client("sts")
response = client.get_caller_identity()
return "sagemaker" in response["Arn"].lower()
except:
except Exception:
return False
@ -313,7 +313,7 @@ def launch_counter() -> None:
print(en["BETA_INVITE"])
with open(JSON_PATH, "w") as j:
j.write(json.dumps(launches))
except:
except Exception:
pass
@ -488,7 +488,7 @@ def async_iteration(iterator):
return next(iterator)
except StopIteration:
# raise a ValueError here because co-routines can't raise StopIteration themselves
raise StopAsyncIteration()
raise StopAsyncIteration() from None
class AsyncRequest:
@ -825,7 +825,7 @@ def get_cancel_function(
]
async def cancel(session_hash: str) -> None:
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
await cancel_tasks(task_ids)
return (

@ -64,14 +64,21 @@ include = [
[tool.ruff]
target-version = "py37"
extend-select = [
"B",
"C",
"I",
# Formatting-related UP rules
"UP030",
"UP031",
"UP032",
]
ignore = [
"C901", # function is too complex (TODO: un-ignore this)
"B023", # function definition in loop (TODO: un-ignore this)
"B008", # function call in argument defaults
"B017", # pytest.raises considered evil
"B028", # explicit stacklevel for warnings
"E501", # from scripts/lint_backend.sh
"E722", # from scripts/lint_backend.sh
"E731", # from scripts/lint_backend.sh
"F403", # from scripts/lint_backend.sh
"F541", # from scripts/lint_backend.sh
]
[tool.ruff.per-file-ignores]

@ -197,7 +197,7 @@ respx==0.19.2
# via -r requirements.in
rfc3986[idna2008]==1.5.0
# via httpx
ruff==0.0.260
ruff==0.0.263
# via -r requirements.in
s3transfer==0.6.0
# via boto3

@ -185,7 +185,7 @@ requests==2.28.1
# transformers
respx==0.19.2
# via -r requirements.in
ruff==0.0.260
ruff==0.0.263
# via -r requirements.in
rfc3986[idna2008]==1.5.0
# via httpx

@ -87,10 +87,11 @@ class TestBlocksMethods:
def fake_func():
return "Hello There"
xray_model = lambda diseases, img: {
disease: random.random() for disease in diseases
}
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
def xray_model(diseases, img):
return {disease: random.random() for disease in diseases}
def ct_model(diseases, img):
return {disease: 0.1 for disease in diseases}
with gr.Blocks() as demo:
gr.Markdown(
@ -405,7 +406,7 @@ class TestComponentsInBlocks:
assert all(dependencies_on_load)
assert len(dependencies_on_load) == 2
# Queue should be explicitly false for these events
assert all([dep["queue"] is False for dep in demo.config["dependencies"]])
assert all(dep["queue"] is False for dep in demo.config["dependencies"])
def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
io_components = [comp for comp in io_components if comp not in [gr.State]]
@ -419,7 +420,7 @@ class TestComponentsInBlocks:
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
]
assert len(dependencies_on_load) == len(io_components)
assert all([dep["every"] == 1 for dep in dependencies_on_load])
assert all(dep["every"] == 1 for dep in dependencies_on_load)
def test_get_load_events(self, io_components):
components = []
@ -451,7 +452,7 @@ class TestBlocksPostprocessing:
0, [gr.update(value=None) for _ in io_components], state={}
)
assert all(
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
o["value"] == c.postprocess(None) for o, c in zip(output, io_components)
)
def test_blocks_does_not_replace_keyword_literal(self):
@ -1213,7 +1214,7 @@ class TestEvery:
# If the continuous event got pushed to the queue, the size would be nonzero
# asserting false will terminate the test
if status.json()["queue_size"] != 0:
assert False
raise AssertionError()
else:
break
@ -1275,7 +1276,7 @@ class TestProgressBar:
for _ in prog.tqdm(range(4), unit="iter"):
time.sleep(0.25)
time.sleep(1)
for i in tqdm(["a", "b", "c"], desc="alphabet"):
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
time.sleep(0.25)
return f"Hello, {s}!"
@ -1331,7 +1332,7 @@ class TestProgressBar:
for _ in prog.tqdm(range(4), unit="iter"):
time.sleep(0.25)
time.sleep(1)
for i in tqdm(["a", "b", "c"], desc="alphabet"):
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
time.sleep(0.25)
return f"Hello, {s}!"

@ -1473,7 +1473,7 @@ class TestNames:
# This test ensures that `components.get_component_instance()` works correctly when instantiating from components
def test_no_duplicate_uncased_names(self):
subclasses = gr.components.Component.__subclasses__()
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
unique_subclasses_uncased = {s.__name__.lower() for s in subclasses}
assert len(subclasses) == len(unique_subclasses_uncased)
@ -2152,7 +2152,7 @@ def test_dataset_calls_as_example(*mocks):
]
],
)
assert all([m.called for m in mocks])
assert all(m.called for m in mocks)
cars = vega_datasets.data.cars()
@ -2194,7 +2194,7 @@ class TestScatterPlot:
x_title="Horse",
)
output = plot.postprocess(cars)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["encoding"]["x"]["field"] == "Horsepower"
assert config["encoding"]["x"]["title"] == "Horse"
@ -2215,7 +2215,7 @@ class TestScatterPlot:
x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
)
output = plot.postprocess(cars)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert "selection" not in config
@ -2224,7 +2224,7 @@ class TestScatterPlot:
x="Horsepower", y="Miles_per_Gallon", height=100, width=200
)
output = plot.postprocess(cars)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
@ -2379,7 +2379,7 @@ class TestLinePlot:
x_title="Trading Day",
)
output = plot.postprocess(stocks)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
for layer in config["layer"]:
assert layer["mark"]["type"] in ["line", "point"]
@ -2394,7 +2394,7 @@ class TestLinePlot:
def test_height_width(self):
plot = gr.LinePlot(x="date", y="price", height=100, width=200)
output = plot.postprocess(stocks)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
@ -2543,7 +2543,7 @@ class TestBarPlot:
x_title="Variable A",
)
output = plot.postprocess(simple)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
assert output["chart"] == "bar"
config = json.loads(output["plot"])
assert config["encoding"]["x"]["field"] == "a"
@ -2558,7 +2558,7 @@ class TestBarPlot:
def test_height_width(self):
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
output = plot.postprocess(simple)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200

@ -265,7 +265,7 @@ class TestLoadInterface:
):
pass
else:
assert False
raise AssertionError()
else:
assert resp.json()["data"] is not None
finally:
@ -345,11 +345,8 @@ class TestLoadInterfaceWithExamples:
def test_root_url(self):
demo = gr.load("spaces/gradio/test-loading-examples")
assert all(
[
c["props"]["root_url"]
== "https://gradio-test-loading-examples.hf.space/"
for c in demo.get_config_file()["components"]
]
c["props"]["root_url"] == "https://gradio-test-loading-examples.hf.space/"
for c in demo.get_config_file()["components"]
)
def test_root_url_deserialization(self):
@ -427,7 +424,7 @@ def check_dataframe(config):
def check_dataset(config, readme_examples):
# No Examples
if not any(readme_examples.values()):
assert not any([c for c in config["components"] if c["type"] == "dataset"])
assert not any(c for c in config["components"] if c["type"] == "dataset")
else:
dataset = next(c for c in config["components"] if c["type"] == "dataset")
assert dataset["props"]["samples"] == [[cols_to_rows(readme_examples)[1]]]

@ -82,7 +82,7 @@ class TestInterface:
)
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
dataset_check = any(
[c["type"] == "dataset" for c in interface.get_config_file()["components"]]
c["type"] == "dataset" for c in interface.get_config_file()["components"]
)
assert dataset_check
@ -117,7 +117,9 @@ class TestInterface:
interface.close()
def test_interface_representation(self):
prediction_fn = lambda x: x
def prediction_fn(x):
return x
prediction_fn.__name__ = "prediction_fn"
repr = str(Interface(prediction_fn, "textbox", "label")).split("\n")
assert prediction_fn.__name__ in repr[0]

@ -12,10 +12,13 @@ from gradio.processing_utils import decode_base64_to_image
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
def max_word_len(text: str) -> int:
return max([len(word) for word in text.split(" ")])
class TestDefault:
@pytest.mark.asyncio
async def test_default_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="default"
)
@ -29,7 +32,6 @@ class TestDefault:
class TestShapley:
@pytest.mark.asyncio
async def test_shapley_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="shapley"
)
@ -42,8 +44,9 @@ class TestShapley:
class TestCustom:
@pytest.mark.asyncio
async def test_custom_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
custom = lambda text: [(char, 1) for char in text]
def custom(text):
return [(char, 1) for char in text]
text_interface = Interface(
max_word_len, "textbox", "label", interpretation=custom
)
@ -54,8 +57,12 @@ class TestCustom:
@pytest.mark.asyncio
async def test_custom_img(self):
max_pixel_value = lambda img: img.max()
custom = lambda img: img.tolist()
def max_pixel_value(img):
return img.max()
def custom(img):
return img.tolist()
img_interface = Interface(
max_pixel_value, "image", "label", interpretation=custom
)

@ -265,8 +265,8 @@ class TestQueueProcessEvents:
# setting up the function to expect further iterative responses.
# Then we provide a 500 response.
side_effects = [
MagicMock(has_exception=False, status=200, json=dict(is_generating=True)),
MagicMock(has_exception=False, status=500, json=dict(error="Foo")),
MagicMock(has_exception=False, status=200, json={"is_generating": True}),
MagicMock(has_exception=False, status=500, json={"error": "Foo"}),
]
mock_event.disconnect = AsyncMock()
queue.gather_event_data = AsyncMock(return_value=True)
@ -301,7 +301,7 @@ class TestQueueProcessEvents:
ValueError("Can't connect"),
]
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
return_value=MagicMock(has_exception=False, json={"is_generating": False})
)
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
@ -326,7 +326,7 @@ class TestQueueProcessEvents:
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json = AsyncMock()
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
return_value=MagicMock(has_exception=False, json={"is_generating": False})
)
# No exception should be raised during `process_event`
mock_event.disconnect = AsyncMock(side_effect=ValueError("..."))
@ -373,7 +373,7 @@ class TestQueueBatch:
mock_event.disconnect.assert_called_once()
mock_event2.disconnect.assert_called_once()
queue.clean_event.call_count == 2
assert queue.clean_event.call_count == 2
class TestGetEventsInBatch:

@ -390,12 +390,12 @@ class TestAuthenticatedRoutes:
response = client.post(
"/login",
data=dict(username="test", password="correct_password"),
data={"username": "test", "password": "correct_password"},
)
assert response.status_code == 200
response = client.post(
"/login",
data=dict(username="test", password="incorrect_password"),
data={"username": "test", "password": "incorrect_password"},
)
assert response.status_code == 400
@ -524,7 +524,7 @@ class TestPassingRequest:
client.post(
"/login",
data=dict(username="admin", password="password"),
data={"username": "admin", "password": "password"},
)
response = client.post("/api/predict/", json={"data": ["test"]})
assert response.status_code == 200