mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-18 12:50:30 +08:00
Ruff update + strictening (#3979)
* Update ruff to 0.0.263 * Get rid of bare except:s * Fix two Ruff E731s by moving expand_{color,size} to the relevant classes * Fix Ruff E731 and some variable shadowing in theme builder * Fix remaining Ruff E731s * Get rid of unused Ruff ignores * Fix ruff B904 issues (raise from) * Fix Ruff B007: loop-control variable not used * Fix Ruff B011 (do not assert false) * Remove unused args and kwargs from Progress.tqdm() (spotted via Ruff B026) * Fix mutable argument default in CheckboxGroup * Noqa ABC-related lint warnings for Interpretable * Add missed assert in test_queueing (ruff B015) * Enable ruff B * Enable ruff C and fix issues * Add changelog * Add UP03[012] after #3984 --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
1f9584f9a7
commit
12a97746ff
@ -88,6 +88,7 @@ No changes to highlight.
|
||||
|
||||
- CI: Python backend lint is only run once, by [@akx](https://github.com/akx) in [PR 3960](https://github.com/gradio-app/gradio/pull/3960)
|
||||
- Format invocations and concatenations were replaced by f-strings where possible by [@akx](https://github.com/akx) in [PR 3984](https://github.com/gradio-app/gradio/pull/3984)
|
||||
- Linting rules were made more strict and issues fixed by [@akx](https://github.com/akx) in [PR 3979](https://github.com/gradio-app/gradio/pull/3979).
|
||||
|
||||
## Breaking Changes:
|
||||
|
||||
|
@ -169,10 +169,10 @@ class Client:
|
||||
"""
|
||||
try:
|
||||
original_info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
|
||||
except RepositoryNotFoundError:
|
||||
except RepositoryNotFoundError as rnfe:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {from_id}. If it is a private Space, please provide an `hf_token`."
|
||||
)
|
||||
) from rnfe
|
||||
if to_id:
|
||||
if "/" in to_id:
|
||||
to_id = to_id.split("/")[1]
|
||||
@ -554,8 +554,10 @@ class Client:
|
||||
result = re.search(r"window.gradio_config = (.*?);[\s]*</script>", r.text)
|
||||
try:
|
||||
config = json.loads(result.group(1)) # type: ignore
|
||||
except AttributeError:
|
||||
raise ValueError(f"Could not get Gradio config from: {self.src}")
|
||||
except AttributeError as ae:
|
||||
raise ValueError(
|
||||
f"Could not get Gradio config from: {self.src}"
|
||||
) from ae
|
||||
if "allow_flagging" in config:
|
||||
raise ValueError(
|
||||
"Gradio 2.x is not supported by this client. Please upgrade your Gradio app to Gradio 3.x or higher."
|
||||
@ -639,20 +641,22 @@ class Endpoint:
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
output = result["data"]
|
||||
except KeyError:
|
||||
except KeyError as ke:
|
||||
is_public_space = (
|
||||
self.client.space_id
|
||||
and not huggingface_hub.space_info(self.client.space_id).private
|
||||
)
|
||||
if "error" in result and "429" in result["error"] and is_public_space:
|
||||
raise utils.TooManyRequestsError(
|
||||
f"Too many requests to the API, please try again later. To avoid being rate-limited, please duplicate the Space using Client.duplicate({self.client.space_id}) and pass in your Hugging Face token."
|
||||
)
|
||||
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
|
||||
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
|
||||
f"and pass in your Hugging Face token."
|
||||
) from None
|
||||
elif "error" in result:
|
||||
raise ValueError(result["error"])
|
||||
raise ValueError(result["error"]) from None
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response. Response received: {result}"
|
||||
)
|
||||
) from ke
|
||||
return tuple(output)
|
||||
|
||||
return _predict
|
||||
|
@ -463,11 +463,11 @@ def set_space_timeout(
|
||||
)
|
||||
try:
|
||||
huggingface_hub.utils.hf_raise_for_status(r)
|
||||
except huggingface_hub.utils.HfHubHTTPError:
|
||||
except huggingface_hub.utils.HfHubHTTPError as err:
|
||||
raise SpaceDuplicationError(
|
||||
f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
|
||||
"to set a timeout manually to reduce billing charges."
|
||||
)
|
||||
) from err
|
||||
|
||||
|
||||
########################
|
||||
|
@ -1,6 +1,6 @@
|
||||
black==22.6.0
|
||||
pytest-asyncio
|
||||
pytest==7.1.2
|
||||
ruff==0.0.260
|
||||
ruff==0.0.263
|
||||
pyright==1.1.298
|
||||
gradio
|
||||
|
@ -987,12 +987,7 @@ class Blocks(BlockContext):
|
||||
is_generating = False
|
||||
|
||||
if block_fn.inputs_as_dict:
|
||||
processed_input = [
|
||||
{
|
||||
input_component: data
|
||||
for input_component, data in zip(block_fn.inputs, processed_input)
|
||||
}
|
||||
]
|
||||
processed_input = [dict(zip(block_fn.inputs, processed_input))]
|
||||
|
||||
if isinstance(requests, list):
|
||||
request = requests[0]
|
||||
@ -1211,10 +1206,11 @@ Received outputs:
|
||||
if predictions[i] is components._Keywords.FINISHED_ITERATING:
|
||||
output.append(None)
|
||||
continue
|
||||
except (IndexError, KeyError):
|
||||
except (IndexError, KeyError) as err:
|
||||
raise ValueError(
|
||||
f"Number of output components does not match number of values returned from from function {block_fn.name}"
|
||||
)
|
||||
"Number of output components does not match number "
|
||||
f"of values returned from from function {block_fn.name}"
|
||||
) from err
|
||||
block = self.blocks[output_id]
|
||||
if getattr(block, "stateful", False):
|
||||
if not utils.is_update(predictions[i]):
|
||||
|
@ -1846,10 +1846,10 @@ class Image(
|
||||
resized_and_cropped_image = np.array(x)
|
||||
try:
|
||||
from skimage.segmentation import slic
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as err:
|
||||
raise ValueError(
|
||||
"Error: running this interpretation for images requires scikit-image, please install it first."
|
||||
)
|
||||
) from err
|
||||
try:
|
||||
segments_slic = slic(
|
||||
resized_and_cropped_image,
|
||||
@ -1880,7 +1880,7 @@ class Image(
|
||||
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
|
||||
tokens, masks, leave_one_out_tokens = [], [], []
|
||||
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
|
||||
for i, segment_value in enumerate(np.unique(segments_slic)):
|
||||
for segment_value in np.unique(segments_slic):
|
||||
mask = segments_slic == segment_value
|
||||
image_screen = np.copy(resized_and_cropped_image)
|
||||
image_screen[segments_slic == segment_value] = replace_color
|
||||
@ -3863,10 +3863,11 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
try:
|
||||
text = y["text"]
|
||||
entities = y["entities"]
|
||||
except KeyError:
|
||||
except KeyError as ke:
|
||||
raise ValueError(
|
||||
"Expected a dictionary with keys 'text' and 'entities' for the value of the HighlightedText component."
|
||||
)
|
||||
"Expected a dictionary with keys 'text' and 'entities' "
|
||||
"for the value of the HighlightedText component."
|
||||
) from ke
|
||||
if len(entities) == 0:
|
||||
y = [(text, None)]
|
||||
else:
|
||||
@ -4049,7 +4050,7 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
|
||||
def hex_to_rgb(value):
|
||||
value = value.lstrip("#")
|
||||
lv = len(value)
|
||||
return list(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))
|
||||
return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)]
|
||||
|
||||
for mask, label in y[1]:
|
||||
mask_array = np.zeros((base_img.shape[0], base_img.shape[1]))
|
||||
@ -5142,18 +5143,18 @@ class ScatterPlot(Plot):
|
||||
):
|
||||
"""Helper for creating the scatter plot."""
|
||||
interactive = True if interactive is None else interactive
|
||||
encodings = dict(
|
||||
x=alt.X(
|
||||
encodings = {
|
||||
"x": alt.X(
|
||||
x, # type: ignore
|
||||
title=x_title or x, # type: ignore
|
||||
scale=AltairPlot.create_scale(x_lim), # type: ignore
|
||||
), # ignore: type
|
||||
y=alt.Y(
|
||||
"y": alt.Y(
|
||||
y, # type: ignore
|
||||
title=y_title or y, # type: ignore
|
||||
scale=AltairPlot.create_scale(y_lim), # type: ignore
|
||||
),
|
||||
)
|
||||
}
|
||||
properties = {}
|
||||
if title:
|
||||
properties["title"] = title
|
||||
@ -5473,18 +5474,18 @@ class LinePlot(Plot):
|
||||
):
|
||||
"""Helper for creating the scatter plot."""
|
||||
interactive = True if interactive is None else interactive
|
||||
encodings = dict(
|
||||
x=alt.X(
|
||||
encodings = {
|
||||
"x": alt.X(
|
||||
x, # type: ignore
|
||||
title=x_title or x, # type: ignore
|
||||
scale=AltairPlot.create_scale(x_lim), # type: ignore
|
||||
),
|
||||
y=alt.Y(
|
||||
"y": alt.Y(
|
||||
y, # type: ignore
|
||||
title=y_title or y, # type: ignore
|
||||
scale=AltairPlot.create_scale(y_lim), # type: ignore
|
||||
),
|
||||
)
|
||||
}
|
||||
properties = {}
|
||||
if title:
|
||||
properties["title"] = title
|
||||
@ -5796,10 +5797,10 @@ class BarPlot(Plot):
|
||||
y_lim: List[int] | None = None,
|
||||
interactive: bool | None = True,
|
||||
):
|
||||
"""Helper for creating the scatter plot."""
|
||||
"""Helper for creating the bar plot."""
|
||||
interactive = True if interactive is None else interactive
|
||||
orientation = (
|
||||
dict(field=group, title=group_title if group_title is not None else group)
|
||||
{"field": group, "title": group_title if group_title is not None else group}
|
||||
if group
|
||||
else {}
|
||||
)
|
||||
@ -6124,7 +6125,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
|
||||
|
||||
# Narrow type to IOComponent
|
||||
assert all(
|
||||
[isinstance(c, IOComponent) for c in self.components]
|
||||
isinstance(c, IOComponent) for c in self.components
|
||||
), "All components in a `Dataset` must be subclasses of `IOComponent`"
|
||||
self.components = [c for c in self.components if isinstance(c, IOComponent)]
|
||||
for component in self.components:
|
||||
@ -6138,7 +6139,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
|
||||
self.label = label
|
||||
if headers is not None:
|
||||
self.headers = headers
|
||||
elif all([c.label is None for c in self.components]):
|
||||
elif all(c.label is None for c in self.components):
|
||||
self.headers = []
|
||||
else:
|
||||
self.headers = [c.label or "" for c in self.components]
|
||||
|
@ -434,8 +434,8 @@ def from_spaces(
|
||||
) # some basic regex to extract the config
|
||||
try:
|
||||
config = json.loads(result.group(1)) # type: ignore
|
||||
except AttributeError:
|
||||
raise ValueError(f"Could not load the Space: {space_name}")
|
||||
except AttributeError as ae:
|
||||
raise ValueError(f"Could not load the Space: {space_name}") from ae
|
||||
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
|
||||
return from_spaces_interface(
|
||||
space_name, config, alias, api_key, iframe_url, **kwargs
|
||||
@ -477,14 +477,14 @@ def from_spaces_interface(
|
||||
data = json.dumps({"data": data})
|
||||
response = requests.post(api_url, headers=headers, data=data)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
if "error" in result and "429" in result["error"]:
|
||||
raise TooManyRequestsError("Too many requests to the Hugging Face API")
|
||||
try:
|
||||
output = result["data"]
|
||||
except KeyError:
|
||||
if "error" in result and "429" in result["error"]:
|
||||
raise TooManyRequestsError("Too many requests to the Hugging Face API")
|
||||
except KeyError as ke:
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response from external Space. Response received: {result}"
|
||||
)
|
||||
) from ke
|
||||
if (
|
||||
len(config["outputs"]) == 1
|
||||
): # if the fn is supposed to return a single value, pop it
|
||||
|
@ -99,12 +99,13 @@ def encode_to_base64(r: requests.Response) -> str:
|
||||
# Case 2: the data prefix is a key in the response
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
content_type = r.json()[0]["content-type"]
|
||||
base64_repr = r.json()[0]["blob"]
|
||||
except KeyError:
|
||||
data = r.json()[0]
|
||||
content_type = data["content-type"]
|
||||
base64_repr = data["blob"]
|
||||
except KeyError as ke:
|
||||
raise ValueError(
|
||||
"Cannot determine content type returned" "by external API."
|
||||
)
|
||||
"Cannot determine content type returned by external API."
|
||||
) from ke
|
||||
# Case 3: the data prefix is included in the response headers
|
||||
else:
|
||||
pass
|
||||
|
@ -276,11 +276,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"""
|
||||
try:
|
||||
import huggingface_hub
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as err:
|
||||
raise ImportError(
|
||||
"Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
|
||||
)
|
||||
) from err
|
||||
hh_version = pkg_resources.get_distribution("huggingface_hub").version
|
||||
try:
|
||||
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
|
||||
@ -416,11 +416,11 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
"""
|
||||
try:
|
||||
import huggingface_hub
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as err:
|
||||
raise ImportError(
|
||||
"Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
|
||||
)
|
||||
) from err
|
||||
hh_version = pkg_resources.get_distribution("huggingface_hub").version
|
||||
try:
|
||||
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
|
||||
@ -510,9 +510,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
csv_data.append(flag_option)
|
||||
|
||||
# Creates metadata dict from row data and dumps it
|
||||
metadata_dict = {
|
||||
header: _csv_data for header, _csv_data in zip(headers, csv_data)
|
||||
}
|
||||
metadata_dict = dict(zip(headers, csv_data))
|
||||
self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
|
||||
|
||||
if is_new:
|
||||
|
@ -422,7 +422,7 @@ class Progress(Iterable):
|
||||
return next(current_iterable.iterable) # type: ignore
|
||||
except StopIteration:
|
||||
self.iterables.pop()
|
||||
raise StopIteration
|
||||
raise
|
||||
else:
|
||||
return self
|
||||
|
||||
@ -463,8 +463,6 @@ class Progress(Iterable):
|
||||
total: int | None = None,
|
||||
unit: str = "steps",
|
||||
_tqdm=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Attaches progress tracker to iterable, like tqdm.
|
||||
@ -539,7 +537,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
|
||||
)
|
||||
if self._progress is not None:
|
||||
self._progress.event_id = event_id
|
||||
self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
|
||||
self._progress.tqdm(iterable, desc, _tqdm=self)
|
||||
kwargs["file"] = open(os.devnull, "w")
|
||||
self.__init__orig__(iterable, desc, *args, **kwargs)
|
||||
|
||||
@ -611,7 +609,7 @@ def special_args(
|
||||
"""
|
||||
signature = inspect.signature(fn)
|
||||
positional_args = []
|
||||
for i, param in enumerate(signature.parameters.values()):
|
||||
for param in signature.parameters.values():
|
||||
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
||||
break
|
||||
positional_args.append(param)
|
||||
|
@ -133,7 +133,7 @@ class CheckboxGroup(components.CheckboxGroup):
|
||||
def __init__(
|
||||
self,
|
||||
choices: List[str],
|
||||
default: List[str] = [],
|
||||
default: List[str] | None = None,
|
||||
type: str = "value",
|
||||
label: Optional[str] = None,
|
||||
optional: bool = False,
|
||||
@ -146,6 +146,8 @@ class CheckboxGroup(components.CheckboxGroup):
|
||||
label (str): component name in interface.
|
||||
optional (bool): this parameter is ignored.
|
||||
"""
|
||||
if default is None:
|
||||
default = []
|
||||
warnings.warn(
|
||||
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",
|
||||
)
|
||||
|
@ -349,9 +349,9 @@ class Interface(Blocks):
|
||||
raise ValueError(
|
||||
"flagging_options must be a list of strings or list of (string, string) tuples."
|
||||
)
|
||||
elif all([isinstance(x, str) for x in flagging_options]):
|
||||
elif all(isinstance(x, str) for x in flagging_options):
|
||||
self.flagging_options = [(f"Flag as {x}", x) for x in flagging_options]
|
||||
elif all([isinstance(x, tuple) for x in flagging_options]):
|
||||
elif all(isinstance(x, tuple) for x in flagging_options):
|
||||
self.flagging_options = flagging_options
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -620,7 +620,7 @@ class Interface(Blocks):
|
||||
for output in self.fn(*args):
|
||||
if len(self.output_components) == 1 and not self.batch:
|
||||
output = [output]
|
||||
output = [o for o in output]
|
||||
output = list(output)
|
||||
yield output + [
|
||||
Button.update(visible=False),
|
||||
Button.update(visible=True),
|
||||
|
@ -16,11 +16,11 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio import Interface
|
||||
|
||||
|
||||
class Interpretable(ABC):
|
||||
class Interpretable(ABC): # noqa: B024
|
||||
def __init__(self) -> None:
|
||||
self.set_interpret_parameters()
|
||||
|
||||
def set_interpret_parameters(self):
|
||||
def set_interpret_parameters(self): # noqa: B027
|
||||
"""
|
||||
Set any parameters for interpretation. Properties can be set here to be
|
||||
used in get_interpretation_neighbors and get_interpretation_scores.
|
||||
@ -189,10 +189,10 @@ async def run_interpret(interface: Interface, raw_input: List):
|
||||
elif interp == "shap" or interp == "shapley":
|
||||
try:
|
||||
import shap # type: ignore
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as err:
|
||||
raise ValueError(
|
||||
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
|
||||
)
|
||||
) from err
|
||||
input_component = interface.input_components[i]
|
||||
if not isinstance(input_component, TokenInterpretable):
|
||||
raise ValueError(
|
||||
|
@ -117,10 +117,11 @@ def start_server(
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind((LOCALHOST_NAME, server_port))
|
||||
s.close()
|
||||
except OSError:
|
||||
except OSError as err:
|
||||
raise OSError(
|
||||
f"Port {server_port} is in use. If a gradio.Blocks is running on the port, you can close() it or gradio.close_all()."
|
||||
)
|
||||
f"Port {server_port} is in use. If a gradio.Blocks is running on the port, "
|
||||
f"you can close() it or gradio.close_all()."
|
||||
) from err
|
||||
port = server_port
|
||||
|
||||
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
|
||||
@ -173,9 +174,8 @@ def setup_tunnel(local_host: str, local_port: int, share_token: str) -> str:
|
||||
address = tunnel.start_tunnel()
|
||||
return address
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e))
|
||||
else:
|
||||
raise RuntimeError("Could not get share link from Gradio API Server.")
|
||||
raise RuntimeError(str(e)) from e
|
||||
raise RuntimeError("Could not get share link from Gradio API Server.")
|
||||
|
||||
|
||||
def url_ok(url: str) -> bool:
|
||||
|
@ -21,10 +21,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
|
||||
try:
|
||||
import transformers
|
||||
from transformers import pipelines
|
||||
except ImportError:
|
||||
except ImportError as ie:
|
||||
raise ImportError(
|
||||
"transformers not installed. Please try `pip install transformers`"
|
||||
)
|
||||
) from ie
|
||||
if not isinstance(pipeline, pipelines.base.Pipeline):
|
||||
raise ValueError("pipeline must be a transformers.Pipeline")
|
||||
|
||||
|
@ -309,7 +309,7 @@ class Queue:
|
||||
"headers": dict(websocket.headers),
|
||||
"query_params": dict(websocket.query_params),
|
||||
"path_params": dict(websocket.path_params),
|
||||
"client": dict(host=websocket.client.host, port=websocket.client.port), # type: ignore
|
||||
"client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore
|
||||
}
|
||||
|
||||
async def call_prediction(self, events: List[Event], batch: bool):
|
||||
@ -444,7 +444,7 @@ class Queue:
|
||||
event.websocket.send_json(data=data), timeout=timeout
|
||||
)
|
||||
return True
|
||||
except:
|
||||
except Exception:
|
||||
await self.clean_event(event)
|
||||
return False
|
||||
|
||||
|
@ -93,8 +93,10 @@ class RangedFileResponse(Response):
|
||||
try:
|
||||
stat_result = await aio_stat(self.path)
|
||||
self.stat_result = stat_result
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"File at path {self.path} does not exist.")
|
||||
except FileNotFoundError as fnfe:
|
||||
raise RuntimeError(
|
||||
f"File at path {self.path} does not exist."
|
||||
) from fnfe
|
||||
else:
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
|
@ -238,17 +238,17 @@ class App(FastAPI):
|
||||
template,
|
||||
{"request": request, "config": config},
|
||||
)
|
||||
except TemplateNotFound:
|
||||
except TemplateNotFound as err:
|
||||
if blocks.share:
|
||||
raise ValueError(
|
||||
"Did you install Gradio from source files? Share mode only "
|
||||
"works when Gradio is installed through the pip package."
|
||||
)
|
||||
) from err
|
||||
else:
|
||||
raise ValueError(
|
||||
"Did you install Gradio from source files? You need to build "
|
||||
"the frontend by running /scripts/build_frontend.sh"
|
||||
)
|
||||
) from err
|
||||
|
||||
@app.get("/info/", dependencies=[Depends(login_check)])
|
||||
@app.get("/info", dependencies=[Depends(login_check)])
|
||||
@ -378,7 +378,7 @@ class App(FastAPI):
|
||||
# the job being cancelled will not overwrite the state of the iterator.
|
||||
# In all cases, should_reset will be the empty set the next time
|
||||
# the fn_index is run.
|
||||
app.iterators[body.session_hash]["should_reset"] = set([])
|
||||
app.iterators[body.session_hash]["should_reset"] = set()
|
||||
else:
|
||||
session_state = {}
|
||||
iterators = {}
|
||||
@ -520,7 +520,7 @@ class App(FastAPI):
|
||||
# Continuous events are not put in the queue so that they do not
|
||||
# occupy the queue's resource as they are expected to run forever
|
||||
if blocks.dependencies[event.fn_index].get("every", 0):
|
||||
await cancel_tasks(set([f"{event.session_hash}_{event.fn_index}"]))
|
||||
await cancel_tasks({f"{event.session_hash}_{event.fn_index}"})
|
||||
await blocks._queue.reset_iterators(event.session_hash, event.fn_index)
|
||||
task = run_coro_in_background(
|
||||
blocks._queue.process_events, [event], False
|
||||
@ -593,9 +593,9 @@ class App(FastAPI):
|
||||
def safe_join(directory: str, path: str) -> str:
|
||||
"""Safely path to a base directory to avoid escaping the base directory.
|
||||
Borrowed from: werkzeug.security.safe_join"""
|
||||
_os_alt_seps: List[str] = list(
|
||||
_os_alt_seps: List[str] = [
|
||||
sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/"
|
||||
)
|
||||
]
|
||||
|
||||
if path == "":
|
||||
raise HTTPException(400)
|
||||
@ -735,8 +735,10 @@ class Request:
|
||||
else:
|
||||
try:
|
||||
obj = self.kwargs[name]
|
||||
except KeyError:
|
||||
raise AttributeError(f"'Request' object has no attribute '{name}'")
|
||||
except KeyError as ke:
|
||||
raise AttributeError(
|
||||
f"'Request' object has no attribute '{name}'"
|
||||
) from ke
|
||||
return self.dict_to_obj(obj)
|
||||
|
||||
|
||||
|
@ -495,32 +495,6 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
|
||||
def load_theme(theme_name):
|
||||
theme = [theme for theme in themes if theme.__name__ == theme_name][0]
|
||||
|
||||
expand_color = lambda color: list(
|
||||
[
|
||||
color.c50,
|
||||
color.c100,
|
||||
color.c200,
|
||||
color.c300,
|
||||
color.c400,
|
||||
color.c500,
|
||||
color.c600,
|
||||
color.c700,
|
||||
color.c800,
|
||||
color.c900,
|
||||
color.c950,
|
||||
]
|
||||
)
|
||||
expand_size = lambda size: list(
|
||||
[
|
||||
size.xxs,
|
||||
size.xs,
|
||||
size.sm,
|
||||
size.md,
|
||||
size.lg,
|
||||
size.xl,
|
||||
size.xxl,
|
||||
]
|
||||
)
|
||||
parameters = inspect.signature(theme.__init__).parameters
|
||||
primary_hue = parameters["primary_hue"].default
|
||||
secondary_hue = parameters["secondary_hue"].default
|
||||
@ -537,14 +511,9 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
|
||||
font_mono_is_google = [
|
||||
isinstance(f, gr.themes.GoogleFont) for f in font_mono
|
||||
]
|
||||
font = [f.name for f in font]
|
||||
font_mono = [f.name for f in font_mono]
|
||||
pad_to_4 = lambda x: x + [None] * (4 - len(x))
|
||||
|
||||
font, font_is_google = pad_to_4(font), pad_to_4(font_is_google)
|
||||
font_mono, font_mono_is_google = pad_to_4(font_mono), pad_to_4(
|
||||
font_mono_is_google
|
||||
)
|
||||
def pad_to_4(x):
|
||||
return x + [None] * (4 - len(x))
|
||||
|
||||
var_output = []
|
||||
for variable in flat_variables:
|
||||
@ -555,17 +524,17 @@ with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as
|
||||
|
||||
return (
|
||||
[primary_hue.name, secondary_hue.name, neutral_hue.name]
|
||||
+ expand_color(primary_hue)
|
||||
+ expand_color(secondary_hue)
|
||||
+ expand_color(neutral_hue)
|
||||
+ primary_hue.expand()
|
||||
+ secondary_hue.expand()
|
||||
+ neutral_hue.expand()
|
||||
+ [text_size.name, spacing_size.name, radius_size.name]
|
||||
+ expand_size(text_size)
|
||||
+ expand_size(spacing_size)
|
||||
+ expand_size(radius_size)
|
||||
+ font
|
||||
+ font_is_google
|
||||
+ font_mono
|
||||
+ font_mono_is_google
|
||||
+ text_size.expand()
|
||||
+ spacing_size.expand()
|
||||
+ radius_size.expand()
|
||||
+ pad_to_4([f.name for f in font])
|
||||
+ pad_to_4(font_is_google)
|
||||
+ pad_to_4(font_mono)
|
||||
+ pad_to_4(font_mono_is_google)
|
||||
+ var_output
|
||||
)
|
||||
|
||||
@ -831,9 +800,7 @@ with gr.Blocks(theme=theme) as demo:
|
||||
font_mono=final_mono_fonts,
|
||||
)
|
||||
|
||||
theme.set(
|
||||
**{attr: val for attr, val in zip(flat_variables, remaining_args)}
|
||||
)
|
||||
theme.set(**dict(zip(flat_variables, remaining_args)))
|
||||
new_step = (base_theme, args)
|
||||
if len(history) == 0 or str(history[-1]) != str(new_step):
|
||||
history.append(new_step)
|
||||
|
@ -33,6 +33,21 @@ class Color:
|
||||
self.name = name
|
||||
Color.all.append(self)
|
||||
|
||||
def expand(self) -> list[str]:
|
||||
return [
|
||||
self.c50,
|
||||
self.c100,
|
||||
self.c200,
|
||||
self.c300,
|
||||
self.c400,
|
||||
self.c500,
|
||||
self.c600,
|
||||
self.c700,
|
||||
self.c800,
|
||||
self.c900,
|
||||
self.c950,
|
||||
]
|
||||
|
||||
|
||||
slate = Color(
|
||||
name="slate",
|
||||
|
@ -1,3 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class Size:
|
||||
all = []
|
||||
|
||||
@ -14,6 +17,9 @@ class Size:
|
||||
self.name = name
|
||||
Size.all.append(self)
|
||||
|
||||
def expand(self) -> list[str]:
|
||||
return [self.xxs, self.xs, self.sm, self.md, self.lg, self.xl, self.xxl]
|
||||
|
||||
|
||||
radius_none = Size(
|
||||
name="radius_none",
|
||||
|
@ -82,7 +82,7 @@ def version_check():
|
||||
warnings.warn("unable to parse version details from package URL.")
|
||||
except KeyError:
|
||||
warnings.warn("package URL does not contain version info.")
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@ -263,7 +263,7 @@ def sagemaker_check() -> bool:
|
||||
client = boto3.client("sts")
|
||||
response = client.get_caller_identity()
|
||||
return "sagemaker" in response["Arn"].lower()
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@ -313,7 +313,7 @@ def launch_counter() -> None:
|
||||
print(en["BETA_INVITE"])
|
||||
with open(JSON_PATH, "w") as j:
|
||||
j.write(json.dumps(launches))
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@ -488,7 +488,7 @@ def async_iteration(iterator):
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
# raise a ValueError here because co-routines can't raise StopIteration themselves
|
||||
raise StopAsyncIteration()
|
||||
raise StopAsyncIteration() from None
|
||||
|
||||
|
||||
class AsyncRequest:
|
||||
@ -825,7 +825,7 @@ def get_cancel_function(
|
||||
]
|
||||
|
||||
async def cancel(session_hash: str) -> None:
|
||||
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
|
||||
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
|
||||
await cancel_tasks(task_ids)
|
||||
|
||||
return (
|
||||
|
@ -64,14 +64,21 @@ include = [
|
||||
[tool.ruff]
|
||||
target-version = "py37"
|
||||
extend-select = [
|
||||
"B",
|
||||
"C",
|
||||
"I",
|
||||
# Formatting-related UP rules
|
||||
"UP030",
|
||||
"UP031",
|
||||
"UP032",
|
||||
]
|
||||
ignore = [
|
||||
"C901", # function is too complex (TODO: un-ignore this)
|
||||
"B023", # function definition in loop (TODO: un-ignore this)
|
||||
"B008", # function call in argument defaults
|
||||
"B017", # pytest.raises considered evil
|
||||
"B028", # explicit stacklevel for warnings
|
||||
"E501", # from scripts/lint_backend.sh
|
||||
"E722", # from scripts/lint_backend.sh
|
||||
"E731", # from scripts/lint_backend.sh
|
||||
"F403", # from scripts/lint_backend.sh
|
||||
"F541", # from scripts/lint_backend.sh
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
|
@ -197,7 +197,7 @@ respx==0.19.2
|
||||
# via -r requirements.in
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
ruff==0.0.260
|
||||
ruff==0.0.263
|
||||
# via -r requirements.in
|
||||
s3transfer==0.6.0
|
||||
# via boto3
|
||||
|
@ -185,7 +185,7 @@ requests==2.28.1
|
||||
# transformers
|
||||
respx==0.19.2
|
||||
# via -r requirements.in
|
||||
ruff==0.0.260
|
||||
ruff==0.0.263
|
||||
# via -r requirements.in
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
|
@ -87,10 +87,11 @@ class TestBlocksMethods:
|
||||
def fake_func():
|
||||
return "Hello There"
|
||||
|
||||
xray_model = lambda diseases, img: {
|
||||
disease: random.random() for disease in diseases
|
||||
}
|
||||
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
|
||||
def xray_model(diseases, img):
|
||||
return {disease: random.random() for disease in diseases}
|
||||
|
||||
def ct_model(diseases, img):
|
||||
return {disease: 0.1 for disease in diseases}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(
|
||||
@ -405,7 +406,7 @@ class TestComponentsInBlocks:
|
||||
assert all(dependencies_on_load)
|
||||
assert len(dependencies_on_load) == 2
|
||||
# Queue should be explicitly false for these events
|
||||
assert all([dep["queue"] is False for dep in demo.config["dependencies"]])
|
||||
assert all(dep["queue"] is False for dep in demo.config["dependencies"])
|
||||
|
||||
def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
|
||||
io_components = [comp for comp in io_components if comp not in [gr.State]]
|
||||
@ -419,7 +420,7 @@ class TestComponentsInBlocks:
|
||||
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
|
||||
]
|
||||
assert len(dependencies_on_load) == len(io_components)
|
||||
assert all([dep["every"] == 1 for dep in dependencies_on_load])
|
||||
assert all(dep["every"] == 1 for dep in dependencies_on_load)
|
||||
|
||||
def test_get_load_events(self, io_components):
|
||||
components = []
|
||||
@ -451,7 +452,7 @@ class TestBlocksPostprocessing:
|
||||
0, [gr.update(value=None) for _ in io_components], state={}
|
||||
)
|
||||
assert all(
|
||||
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
|
||||
o["value"] == c.postprocess(None) for o, c in zip(output, io_components)
|
||||
)
|
||||
|
||||
def test_blocks_does_not_replace_keyword_literal(self):
|
||||
@ -1213,7 +1214,7 @@ class TestEvery:
|
||||
# If the continuous event got pushed to the queue, the size would be nonzero
|
||||
# asserting false will terminate the test
|
||||
if status.json()["queue_size"] != 0:
|
||||
assert False
|
||||
raise AssertionError()
|
||||
else:
|
||||
break
|
||||
|
||||
@ -1275,7 +1276,7 @@ class TestProgressBar:
|
||||
for _ in prog.tqdm(range(4), unit="iter"):
|
||||
time.sleep(0.25)
|
||||
time.sleep(1)
|
||||
for i in tqdm(["a", "b", "c"], desc="alphabet"):
|
||||
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
|
||||
time.sleep(0.25)
|
||||
return f"Hello, {s}!"
|
||||
|
||||
@ -1331,7 +1332,7 @@ class TestProgressBar:
|
||||
for _ in prog.tqdm(range(4), unit="iter"):
|
||||
time.sleep(0.25)
|
||||
time.sleep(1)
|
||||
for i in tqdm(["a", "b", "c"], desc="alphabet"):
|
||||
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
|
||||
time.sleep(0.25)
|
||||
return f"Hello, {s}!"
|
||||
|
||||
|
@ -1473,7 +1473,7 @@ class TestNames:
|
||||
# This test ensures that `components.get_component_instance()` works correctly when instantiating from components
|
||||
def test_no_duplicate_uncased_names(self):
|
||||
subclasses = gr.components.Component.__subclasses__()
|
||||
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
|
||||
unique_subclasses_uncased = {s.__name__.lower() for s in subclasses}
|
||||
assert len(subclasses) == len(unique_subclasses_uncased)
|
||||
|
||||
|
||||
@ -2152,7 +2152,7 @@ def test_dataset_calls_as_example(*mocks):
|
||||
]
|
||||
],
|
||||
)
|
||||
assert all([m.called for m in mocks])
|
||||
assert all(m.called for m in mocks)
|
||||
|
||||
|
||||
cars = vega_datasets.data.cars()
|
||||
@ -2194,7 +2194,7 @@ class TestScatterPlot:
|
||||
x_title="Horse",
|
||||
)
|
||||
output = plot.postprocess(cars)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "Horsepower"
|
||||
assert config["encoding"]["x"]["title"] == "Horse"
|
||||
@ -2215,7 +2215,7 @@ class TestScatterPlot:
|
||||
x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
|
||||
)
|
||||
output = plot.postprocess(cars)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert "selection" not in config
|
||||
|
||||
@ -2224,7 +2224,7 @@ class TestScatterPlot:
|
||||
x="Horsepower", y="Miles_per_Gallon", height=100, width=200
|
||||
)
|
||||
output = plot.postprocess(cars)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
@ -2379,7 +2379,7 @@ class TestLinePlot:
|
||||
x_title="Trading Day",
|
||||
)
|
||||
output = plot.postprocess(stocks)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
for layer in config["layer"]:
|
||||
assert layer["mark"]["type"] in ["line", "point"]
|
||||
@ -2394,7 +2394,7 @@ class TestLinePlot:
|
||||
def test_height_width(self):
|
||||
plot = gr.LinePlot(x="date", y="price", height=100, width=200)
|
||||
output = plot.postprocess(stocks)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
@ -2543,7 +2543,7 @@ class TestBarPlot:
|
||||
x_title="Variable A",
|
||||
)
|
||||
output = plot.postprocess(simple)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
assert output["chart"] == "bar"
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "a"
|
||||
@ -2558,7 +2558,7 @@ class TestBarPlot:
|
||||
def test_height_width(self):
|
||||
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
|
||||
output = plot.postprocess(simple)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
@ -265,7 +265,7 @@ class TestLoadInterface:
|
||||
):
|
||||
pass
|
||||
else:
|
||||
assert False
|
||||
raise AssertionError()
|
||||
else:
|
||||
assert resp.json()["data"] is not None
|
||||
finally:
|
||||
@ -345,11 +345,8 @@ class TestLoadInterfaceWithExamples:
|
||||
def test_root_url(self):
|
||||
demo = gr.load("spaces/gradio/test-loading-examples")
|
||||
assert all(
|
||||
[
|
||||
c["props"]["root_url"]
|
||||
== "https://gradio-test-loading-examples.hf.space/"
|
||||
for c in demo.get_config_file()["components"]
|
||||
]
|
||||
c["props"]["root_url"] == "https://gradio-test-loading-examples.hf.space/"
|
||||
for c in demo.get_config_file()["components"]
|
||||
)
|
||||
|
||||
def test_root_url_deserialization(self):
|
||||
@ -427,7 +424,7 @@ def check_dataframe(config):
|
||||
def check_dataset(config, readme_examples):
|
||||
# No Examples
|
||||
if not any(readme_examples.values()):
|
||||
assert not any([c for c in config["components"] if c["type"] == "dataset"])
|
||||
assert not any(c for c in config["components"] if c["type"] == "dataset")
|
||||
else:
|
||||
dataset = next(c for c in config["components"] if c["type"] == "dataset")
|
||||
assert dataset["props"]["samples"] == [[cols_to_rows(readme_examples)[1]]]
|
||||
|
@ -82,7 +82,7 @@ class TestInterface:
|
||||
)
|
||||
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
|
||||
dataset_check = any(
|
||||
[c["type"] == "dataset" for c in interface.get_config_file()["components"]]
|
||||
c["type"] == "dataset" for c in interface.get_config_file()["components"]
|
||||
)
|
||||
assert dataset_check
|
||||
|
||||
@ -117,7 +117,9 @@ class TestInterface:
|
||||
interface.close()
|
||||
|
||||
def test_interface_representation(self):
|
||||
prediction_fn = lambda x: x
|
||||
def prediction_fn(x):
|
||||
return x
|
||||
|
||||
prediction_fn.__name__ = "prediction_fn"
|
||||
repr = str(Interface(prediction_fn, "textbox", "label")).split("\n")
|
||||
assert prediction_fn.__name__ in repr[0]
|
||||
|
@ -12,10 +12,13 @@ from gradio.processing_utils import decode_base64_to_image
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
def max_word_len(text: str) -> int:
|
||||
return max([len(word) for word in text.split(" ")])
|
||||
|
||||
|
||||
class TestDefault:
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation="default"
|
||||
)
|
||||
@ -29,7 +32,6 @@ class TestDefault:
|
||||
class TestShapley:
|
||||
@pytest.mark.asyncio
|
||||
async def test_shapley_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation="shapley"
|
||||
)
|
||||
@ -42,8 +44,9 @@ class TestShapley:
|
||||
class TestCustom:
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
custom = lambda text: [(char, 1) for char in text]
|
||||
def custom(text):
|
||||
return [(char, 1) for char in text]
|
||||
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation=custom
|
||||
)
|
||||
@ -54,8 +57,12 @@ class TestCustom:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_img(self):
|
||||
max_pixel_value = lambda img: img.max()
|
||||
custom = lambda img: img.tolist()
|
||||
def max_pixel_value(img):
|
||||
return img.max()
|
||||
|
||||
def custom(img):
|
||||
return img.tolist()
|
||||
|
||||
img_interface = Interface(
|
||||
max_pixel_value, "image", "label", interpretation=custom
|
||||
)
|
||||
|
@ -265,8 +265,8 @@ class TestQueueProcessEvents:
|
||||
# setting up the function to expect further iterative responses.
|
||||
# Then we provide a 500 response.
|
||||
side_effects = [
|
||||
MagicMock(has_exception=False, status=200, json=dict(is_generating=True)),
|
||||
MagicMock(has_exception=False, status=500, json=dict(error="Foo")),
|
||||
MagicMock(has_exception=False, status=200, json={"is_generating": True}),
|
||||
MagicMock(has_exception=False, status=500, json={"error": "Foo"}),
|
||||
]
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.gather_event_data = AsyncMock(return_value=True)
|
||||
@ -301,7 +301,7 @@ class TestQueueProcessEvents:
|
||||
ValueError("Can't connect"),
|
||||
]
|
||||
queue.call_prediction = AsyncMock(
|
||||
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
|
||||
return_value=MagicMock(has_exception=False, json={"is_generating": False})
|
||||
)
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.clean_event = AsyncMock()
|
||||
@ -326,7 +326,7 @@ class TestQueueProcessEvents:
|
||||
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
queue.call_prediction = AsyncMock(
|
||||
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
|
||||
return_value=MagicMock(has_exception=False, json={"is_generating": False})
|
||||
)
|
||||
# No exception should be raised during `process_event`
|
||||
mock_event.disconnect = AsyncMock(side_effect=ValueError("..."))
|
||||
@ -373,7 +373,7 @@ class TestQueueBatch:
|
||||
mock_event.disconnect.assert_called_once()
|
||||
|
||||
mock_event2.disconnect.assert_called_once()
|
||||
queue.clean_event.call_count == 2
|
||||
assert queue.clean_event.call_count == 2
|
||||
|
||||
|
||||
class TestGetEventsInBatch:
|
||||
|
@ -390,12 +390,12 @@ class TestAuthenticatedRoutes:
|
||||
|
||||
response = client.post(
|
||||
"/login",
|
||||
data=dict(username="test", password="correct_password"),
|
||||
data={"username": "test", "password": "correct_password"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response = client.post(
|
||||
"/login",
|
||||
data=dict(username="test", password="incorrect_password"),
|
||||
data={"username": "test", "password": "incorrect_password"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
@ -524,7 +524,7 @@ class TestPassingRequest:
|
||||
|
||||
client.post(
|
||||
"/login",
|
||||
data=dict(username="admin", password="password"),
|
||||
data={"username": "admin", "password": "password"},
|
||||
)
|
||||
response = client.post("/api/predict/", json={"data": ["test"]})
|
||||
assert response.status_code == 200
|
||||
|
Loading…
x
Reference in New Issue
Block a user