mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Assert refactor in external.py (#5811)
* Refactored assert statements to if statements * format-addons * format * add changeset * Update gradio/external.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * refactored video.py * refactored all the assert statements with response * add changeset * add changeset * Apply suggestions from code review * Refactored documentation.py and few more files * avoid circular * Replaced all assert statements * lint * notebooks * fix * minor changes * final changes according to tests * Lint * last fix * fix * fix utils test * fix serialization error * fix serialization error --------- Co-authored-by: harry-urek <hariombhardwaj038@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
48e09ee887
commit
1d5b15a2d2
6
.changeset/light-buses-enter.md
Normal file
6
.changeset/light-buses-enter.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"gradio": patch
|
||||
"gradio_client": patch
|
||||
---
|
||||
|
||||
fix:Assert refactor in external.py
|
@ -31,6 +31,7 @@ from packaging import version
|
||||
|
||||
from gradio_client import serializing, utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
from gradio_client.exceptions import SerializationSetupError
|
||||
from gradio_client.serializing import Serializable
|
||||
from gradio_client.utils import (
|
||||
Communicator,
|
||||
@ -646,9 +647,8 @@ class Client:
|
||||
raise ValueError(
|
||||
f"Each entry in api_names must be either a string or a tuple of strings. Received {api_names}"
|
||||
)
|
||||
assert (
|
||||
len(api_names) == 1
|
||||
), "Currently only one api_name can be deployed to discord."
|
||||
if len(api_names) != 1:
|
||||
raise ValueError("Currently only one api_name can be deployed to discord.")
|
||||
|
||||
for i, name in enumerate(api_names):
|
||||
if isinstance(name, str):
|
||||
@ -676,8 +676,8 @@ class Client:
|
||||
is_private = False
|
||||
if self.space_id:
|
||||
is_private = huggingface_hub.space_info(self.space_id).private
|
||||
if is_private:
|
||||
assert hf_token, (
|
||||
if is_private and not hf_token:
|
||||
raise ValueError(
|
||||
f"Since {self.space_id} is private, you must explicitly pass in hf_token "
|
||||
"so that it can be added as a secret in the discord bot space."
|
||||
)
|
||||
@ -777,7 +777,7 @@ class Endpoint:
|
||||
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
|
||||
self.serializers, self.deserializers = self._setup_serializers()
|
||||
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
|
||||
except AssertionError:
|
||||
except SerializationSetupError:
|
||||
self.is_valid = False
|
||||
|
||||
def __repr__(self):
|
||||
@ -952,9 +952,10 @@ class Endpoint:
|
||||
return data
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
assert len(data) == len(
|
||||
self.serializers
|
||||
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
if len(data) != len(self.serializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
)
|
||||
|
||||
files = [
|
||||
f
|
||||
@ -968,9 +969,10 @@ class Endpoint:
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> tuple:
|
||||
assert len(data) == len(
|
||||
self.deserializers
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
if len(data) != len(self.deserializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
)
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(
|
||||
@ -1002,15 +1004,17 @@ class Endpoint:
|
||||
self.input_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
assert (
|
||||
serializer_name in serializing.SERIALIZER_MAPPING
|
||||
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
else:
|
||||
assert (
|
||||
component_name in serializing.COMPONENT_MAPPING
|
||||
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
serializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializers.append(serializer()) # type: ignore
|
||||
|
||||
outputs = self.dependency["outputs"]
|
||||
@ -1022,17 +1026,19 @@ class Endpoint:
|
||||
self.output_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
assert (
|
||||
serializer_name in serializing.SERIALIZER_MAPPING
|
||||
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in utils.SKIP_COMPONENTS:
|
||||
deserializer = serializing.SimpleSerializable
|
||||
else:
|
||||
assert (
|
||||
component_name in serializing.COMPONENT_MAPPING
|
||||
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
deserializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializers.append(deserializer()) # type: ignore
|
||||
|
||||
return serializers, deserializers
|
||||
|
@ -26,16 +26,18 @@ def extract_instance_attr_doc(cls, attr):
|
||||
"self." + attr + " ="
|
||||
):
|
||||
break
|
||||
assert i is not None, f"Could not find {attr} in {cls.__name__}"
|
||||
if i is None:
|
||||
raise NameError(f"Could not find {attr} in {cls.__name__}")
|
||||
start_line = lines.index('"""', i)
|
||||
end_line = lines.index('"""', start_line + 1)
|
||||
for j in range(i + 1, start_line):
|
||||
assert not lines[j].startswith("self."), (
|
||||
f"Found another attribute before docstring for {attr} in {cls.__name__}: "
|
||||
+ lines[j]
|
||||
+ "\n start:"
|
||||
+ lines[i]
|
||||
)
|
||||
if lines[j].startswith("self."):
|
||||
raise ValueError(
|
||||
f"Found another attribute before docstring for {attr} in {cls.__name__}: "
|
||||
+ lines[j]
|
||||
+ "\n start:"
|
||||
+ lines[i]
|
||||
)
|
||||
doc_string = " ".join(lines[start_line + 1 : end_line])
|
||||
return doc_string
|
||||
|
||||
@ -95,15 +97,17 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
|
||||
continue
|
||||
if not (line.startswith(" ") or line.strip() == ""):
|
||||
print(line)
|
||||
assert (
|
||||
line.startswith(" ") or line.strip() == ""
|
||||
), f"Documentation format for {fn.__name__} has format error in line: {line}"
|
||||
if not (line.startswith(" ") or line.strip() == ""):
|
||||
raise SyntaxError(
|
||||
f"Documentation format for {fn.__name__} has format error in line: {line}"
|
||||
)
|
||||
line = line[4:]
|
||||
if mode == "parameter":
|
||||
colon_index = line.index(": ")
|
||||
assert (
|
||||
colon_index > -1
|
||||
), f"Documentation format for {fn.__name__} has format error in line: {line}"
|
||||
if colon_index < -1:
|
||||
raise SyntaxError(
|
||||
f"Documentation format for {fn.__name__} has format error in line: {line}"
|
||||
)
|
||||
parameter = line[:colon_index]
|
||||
parameter_doc = line[colon_index + 2 :]
|
||||
parameters[parameter] = parameter_doc
|
||||
@ -172,9 +176,10 @@ def document_cls(cls):
|
||||
if mode == "description":
|
||||
description_lines.append(line if line.strip() else "<br>")
|
||||
else:
|
||||
assert (
|
||||
line.startswith(" ") or not line.strip()
|
||||
), f"Documentation format for {cls.__name__} has format error in line: {line}"
|
||||
if not (line.startswith(" ") or not line.strip()):
|
||||
raise SyntaxError(
|
||||
f"Documentation format for {cls.__name__} has format error in line: {line}"
|
||||
)
|
||||
tags[mode].append(line[4:])
|
||||
if "example" in tags:
|
||||
example = "\n".join(tags["example"])
|
||||
|
4
client/python/gradio_client/exceptions.py
Normal file
4
client/python/gradio_client/exceptions.py
Normal file
@ -0,0 +1,4 @@
|
||||
class SerializationSetupError(ValueError):
|
||||
"""Raised when a serializers cannot be set up correctly."""
|
||||
|
||||
pass
|
@ -307,7 +307,8 @@ class FileSerializable(Serializable):
|
||||
elif isinstance(x, dict):
|
||||
if x.get("is_file"):
|
||||
filepath = x.get("name")
|
||||
assert filepath is not None, f"The 'name' field is missing in {x}"
|
||||
if filepath is None:
|
||||
raise ValueError(f"The 'name' field is missing in {x}")
|
||||
if root_url is not None:
|
||||
file_name = utils.download_tmp_copy_of_file(
|
||||
root_url + "file=" + filepath,
|
||||
@ -331,7 +332,8 @@ class FileSerializable(Serializable):
|
||||
file_name = str(path)
|
||||
else:
|
||||
data = x.get("data")
|
||||
assert data is not None, f"The 'data' field is missing in {x}"
|
||||
if data is None:
|
||||
raise ValueError(f"The 'data' field is missing in {x}")
|
||||
file_name = utils.decode_base64_to_file(data, dir=save_dir).name
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -426,7 +428,8 @@ class VideoSerializable(FileSerializable):
|
||||
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
|
||||
"""
|
||||
if isinstance(x, (tuple, list)):
|
||||
assert len(x) == 2, f"Expected tuple of length 2. Received: {x}"
|
||||
if len(x) != 2:
|
||||
raise ValueError(f"Expected tuple of length 2. Received: {x}")
|
||||
x_as_list = [x[0], x[1]]
|
||||
else:
|
||||
raise ValueError(f"Expected tuple of length 2. Received: {x}")
|
||||
|
File diff suppressed because one or more lines are too long
@ -20,7 +20,8 @@ FIGSIZE = 7, 7 # does not affect size in webpage
|
||||
COLORS = [
|
||||
'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
|
||||
]
|
||||
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
|
||||
if len(COLORS) <= MAX_CLUSTERS:
|
||||
raise ValueError("Not enough different colors for all clusters")
|
||||
np.random.seed(SEED)
|
||||
|
||||
|
||||
|
@ -44,6 +44,7 @@ from gradio.exceptions import (
|
||||
DuplicateBlockError,
|
||||
InvalidApiNameError,
|
||||
InvalidBlockError,
|
||||
InvalidComponentError,
|
||||
)
|
||||
from gradio.helpers import EventData, create_tracker, skip, special_args
|
||||
from gradio.state_holder import SessionState
|
||||
@ -364,9 +365,10 @@ def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool =
|
||||
attr_dict["__type__"] = "update"
|
||||
attr_dict.pop("value", None)
|
||||
if "value" in update_dict:
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"Component {block.__class__} does not support value"
|
||||
if not isinstance(block, components.IOComponent):
|
||||
raise InvalidComponentError(
|
||||
f"Component {block.__class__} does not support value"
|
||||
)
|
||||
if postprocess:
|
||||
attr_dict["value"] = block.postprocess(update_dict["value"])
|
||||
else:
|
||||
@ -766,9 +768,10 @@ class Blocks(BlockContext):
|
||||
|
||||
children = child_config.get("children")
|
||||
if children is not None:
|
||||
assert isinstance(
|
||||
block, BlockContext
|
||||
), f"Invalid config, Block with id {id} has children but is not a BlockContext."
|
||||
if not isinstance(block, BlockContext):
|
||||
raise ValueError(
|
||||
f"Invalid config, Block with id {id} has children but is not a BlockContext."
|
||||
)
|
||||
with block:
|
||||
iterate_over_children(children)
|
||||
|
||||
@ -1158,7 +1161,8 @@ class Blocks(BlockContext):
|
||||
event_data: data associated with event trigger
|
||||
"""
|
||||
block_fn = self.fns[fn_index]
|
||||
assert block_fn.fn, f"function with index {fn_index} not defined."
|
||||
if not block_fn.fn:
|
||||
raise IndexError(f"function with index {fn_index} not defined.")
|
||||
is_generating = False
|
||||
request = requests[0] if isinstance(requests, list) else requests
|
||||
start = time.time()
|
||||
@ -1234,9 +1238,10 @@ class Blocks(BlockContext):
|
||||
raise InvalidBlockError(
|
||||
f"Input component with id {input_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
if not isinstance(block, components.IOComponent):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
)
|
||||
serialized_input = block.serialize(inputs[i])
|
||||
processed_input.append(serialized_input)
|
||||
|
||||
@ -1253,9 +1258,10 @@ class Blocks(BlockContext):
|
||||
raise InvalidBlockError(
|
||||
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
if not isinstance(block, components.IOComponent):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
)
|
||||
deserialized = block.deserialize(
|
||||
outputs[o],
|
||||
save_dir=block.DEFAULT_TEMP_DIR,
|
||||
@ -1322,9 +1328,10 @@ Received inputs:
|
||||
raise InvalidBlockError(
|
||||
f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
assert isinstance(
|
||||
block, components.Component
|
||||
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
)
|
||||
if getattr(block, "stateful", False):
|
||||
processed_input.append(state[input_id])
|
||||
else:
|
||||
@ -1445,9 +1452,10 @@ Received outputs:
|
||||
postprocess=block_fn.postprocess,
|
||||
)
|
||||
elif block_fn.postprocess:
|
||||
assert isinstance(
|
||||
block, components.Component
|
||||
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
if not isinstance(block, components.Component):
|
||||
raise InvalidComponentError(
|
||||
f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
)
|
||||
prediction_value = block.postprocess(prediction_value)
|
||||
output.append(prediction_value)
|
||||
|
||||
@ -2005,9 +2013,8 @@ Received outputs:
|
||||
)
|
||||
|
||||
if self.is_running:
|
||||
assert isinstance(
|
||||
self.local_url, str
|
||||
), f"Invalid local_url: {self.local_url}"
|
||||
if not isinstance(self.local_url, str):
|
||||
raise ValueError(f"Invalid local_url: {self.local_url}")
|
||||
if not (quiet):
|
||||
print(
|
||||
"Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
|
||||
|
@ -205,12 +205,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable):
|
||||
return y
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
assert isinstance(
|
||||
message_pair, (tuple, list)
|
||||
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
assert (
|
||||
len(message_pair) == 2
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
)
|
||||
if len(message_pair) != 2:
|
||||
raise TypeError(
|
||||
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
)
|
||||
processed_messages.append(
|
||||
[
|
||||
self._preprocess_chat_messages(message_pair[0]),
|
||||
@ -259,12 +261,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable):
|
||||
return []
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
assert isinstance(
|
||||
message_pair, (tuple, list)
|
||||
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
assert (
|
||||
len(message_pair) == 2
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
)
|
||||
if len(message_pair) != 2:
|
||||
raise TypeError(
|
||||
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
)
|
||||
processed_messages.append(
|
||||
[
|
||||
self._postprocess_chat_messages(message_pair[0]),
|
||||
|
@ -81,7 +81,9 @@ class Code(Changeable, Inputable, IOComponent, StringSerializable):
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
assert language in Code.languages, f"Language {language} not supported."
|
||||
if language not in Code.languages:
|
||||
raise ValueError(f"Language {language} not supported.")
|
||||
|
||||
self.language = language
|
||||
self.lines = lines
|
||||
IOComponent.__init__(
|
||||
|
@ -280,7 +280,8 @@ class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable
|
||||
return self.postprocess([[]])
|
||||
if isinstance(y, np.ndarray):
|
||||
y = y.tolist()
|
||||
assert isinstance(y, list), "output cannot be converted to list"
|
||||
if not isinstance(y, list):
|
||||
raise ValueError("output cannot be converted to list")
|
||||
|
||||
_headers = self.headers
|
||||
if len(self.headers) < len(y[0]):
|
||||
|
@ -68,9 +68,10 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
|
||||
self._components = [get_component_instance(c) for c in components]
|
||||
|
||||
# Narrow type to IOComponent
|
||||
assert all(
|
||||
isinstance(c, IOComponent) for c in self._components
|
||||
), "All components in a `Dataset` must be subclasses of `IOComponent`"
|
||||
if not all(isinstance(c, IOComponent) for c in self._components):
|
||||
raise ValueError(
|
||||
"All components in a `Dataset` must be subclasses of `IOComponent`"
|
||||
)
|
||||
self._components = [c for c in self._components if isinstance(c, IOComponent)]
|
||||
for component in self._components:
|
||||
component.root_url = self.root_url
|
||||
|
@ -193,14 +193,16 @@ class Video(
|
||||
)
|
||||
|
||||
if is_file:
|
||||
assert file_name is not None, "Received file data without a file name."
|
||||
if file_name is None:
|
||||
raise ValueError("Received file data without a file name.")
|
||||
if client_utils.is_http_url_like(file_name):
|
||||
fn = self.download_temp_copy_if_needed
|
||||
else:
|
||||
fn = self.make_temp_copy_if_needed
|
||||
file_name = Path(fn(file_name))
|
||||
else:
|
||||
assert file_data is not None, "Received empty file data."
|
||||
if file_data is None:
|
||||
raise ValueError("Received empty file data.")
|
||||
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))
|
||||
|
||||
uploaded_format = file_name.suffix.replace(".", "")
|
||||
@ -270,12 +272,15 @@ class Video(
|
||||
if isinstance(y, (str, Path)):
|
||||
processed_files = (self._format_video(y), None)
|
||||
elif isinstance(y, (tuple, list)):
|
||||
assert (
|
||||
len(y) == 2
|
||||
), f"Expected lists of length 2 or tuples of length 2. Received: {y}"
|
||||
assert isinstance(y[0], (str, Path)) and isinstance(
|
||||
y[1], (str, Path)
|
||||
), f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}"
|
||||
if len(y) != 2:
|
||||
raise ValueError(
|
||||
f"Expected lists of length 2 or tuples of length 2. Received: {y}"
|
||||
)
|
||||
|
||||
if not (isinstance(y[0], (str, Path)) and isinstance(y[1], (str, Path))):
|
||||
raise TypeError(
|
||||
f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}"
|
||||
)
|
||||
video = y[0]
|
||||
subtitle = y[1]
|
||||
processed_files = (
|
||||
|
@ -9,12 +9,30 @@ class DuplicateBlockError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidComponentError(ValueError):
|
||||
"""Raised when invalid components are used."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TooManyRequestsError(Exception):
|
||||
"""Raised when the Hugging Face API returns a 429 status code."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
"""Raised when the provided model doesn't exists or is not found by the provided api url."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RenderError(Exception):
|
||||
"""Raised when a component has not been rendered in the current Blocks but is expected to have been rendered."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidApiNameError(ValueError):
|
||||
pass
|
||||
|
||||
|
@ -16,7 +16,7 @@ import gradio
|
||||
from gradio import components, utils
|
||||
from gradio.context import Context
|
||||
from gradio.deprecation import warn_deprecation
|
||||
from gradio.exceptions import Error, TooManyRequestsError
|
||||
from gradio.exceptions import Error, ModelNotFoundError, TooManyRequestsError
|
||||
from gradio.external_utils import (
|
||||
cols_to_rows,
|
||||
encode_to_base64,
|
||||
@ -83,9 +83,10 @@ def load_blocks_from_repo(
|
||||
if src is None:
|
||||
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
|
||||
tokens = name.split("/")
|
||||
assert (
|
||||
len(tokens) > 1
|
||||
), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
|
||||
if len(tokens) <= 1:
|
||||
raise ValueError(
|
||||
"Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
|
||||
)
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
|
||||
@ -95,9 +96,8 @@ def load_blocks_from_repo(
|
||||
"models": from_model,
|
||||
"spaces": from_spaces,
|
||||
}
|
||||
assert (
|
||||
src.lower() in factory_methods
|
||||
), f"parameter: src must be one of {factory_methods.keys()}"
|
||||
if src.lower() not in factory_methods:
|
||||
raise ValueError(f"parameter: src must be one of {factory_methods.keys()}")
|
||||
|
||||
if hf_token is not None:
|
||||
if Context.hf_token is not None and Context.hf_token != hf_token:
|
||||
@ -145,9 +145,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
|
||||
|
||||
# Checking if model exists, and if so, it gets the pipeline
|
||||
response = requests.request("GET", api_url, headers=headers)
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
|
||||
if response.status_code != 200:
|
||||
raise ModelNotFoundError(
|
||||
f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
|
||||
)
|
||||
p = response.json().get("pipeline_tag")
|
||||
pipelines = {
|
||||
"audio-classification": {
|
||||
|
@ -525,7 +525,8 @@ class Progress(Iterable):
|
||||
):
|
||||
current_iterable = self.iterables.pop()
|
||||
callback(self.iterables)
|
||||
assert current_iterable.index is not None, "Index not set."
|
||||
if current_iterable.index is None:
|
||||
raise IndexError("Index not set.")
|
||||
current_iterable.index += 1
|
||||
try:
|
||||
return next(current_iterable.iterable) # type: ignore
|
||||
@ -603,7 +604,8 @@ class Progress(Iterable):
|
||||
callback = self._progress_callback()
|
||||
if callback and len(self.iterables) > 0:
|
||||
current_iterable = self.iterables[-1]
|
||||
assert current_iterable.index is not None, "Index not set."
|
||||
if current_iterable.index is None:
|
||||
raise IndexError("Index not set.")
|
||||
current_iterable.index += n
|
||||
callback(self.iterables)
|
||||
else:
|
||||
|
@ -28,6 +28,7 @@ from gradio.components import (
|
||||
from gradio.data_classes import InterfaceTypes
|
||||
from gradio.deprecation import warn_deprecation
|
||||
from gradio.events import Changeable, Streamable, Submittable, on
|
||||
from gradio.exceptions import RenderError
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
||||
from gradio.layouts import Column, Row, Tab, Tabs
|
||||
from gradio.pipelines import load_from_pipeline
|
||||
@ -449,7 +450,8 @@ class Interface(Blocks):
|
||||
stop_btn = stop_btn or stop_btn_2_out
|
||||
flag_btns = flag_btns or flag_btns_out
|
||||
|
||||
assert clear_btn is not None, "Clear button not rendered"
|
||||
if clear_btn is None:
|
||||
raise RenderError("Clear button not rendered")
|
||||
|
||||
self.attach_submit_events(submit_btn, stop_btn)
|
||||
self.attach_clear_events(
|
||||
@ -586,7 +588,8 @@ class Interface(Blocks):
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = self.render_flag_btns()
|
||||
elif self.allow_flagging == "auto":
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
if submit_btn is None:
|
||||
raise RenderError("Submit button not rendered")
|
||||
flag_btns = [submit_btn]
|
||||
|
||||
if self.interpretation:
|
||||
@ -611,7 +614,8 @@ class Interface(Blocks):
|
||||
def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
|
||||
if self.live:
|
||||
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
if submit_btn is None:
|
||||
raise RenderError("Submit button not rendered")
|
||||
super().load(self.fn, None, self.output_components)
|
||||
# For output-only interfaces, the user probably still want a "generate"
|
||||
# button even if the Interface is live
|
||||
@ -642,7 +646,8 @@ class Interface(Blocks):
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
else:
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
if submit_btn is None:
|
||||
raise RenderError("Submit button not rendered")
|
||||
fn = self.fn
|
||||
extra_output = []
|
||||
|
||||
|
@ -230,7 +230,8 @@ async def run_interpret(interface: Interface, raw_input: list):
|
||||
nsamples=int(interface.num_shap * num_total_segments),
|
||||
silent=True,
|
||||
)
|
||||
assert shap_values is not None, "SHAP values could not be calculated"
|
||||
if shap_values is None:
|
||||
raise ValueError("SHAP values could not be calculated")
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i],
|
||||
|
@ -368,7 +368,8 @@ class Queue:
|
||||
|
||||
async def call_prediction(self, events: list[Event], batch: bool):
|
||||
body = events[0].data
|
||||
assert body is not None, "No event data"
|
||||
if body is None:
|
||||
raise ValueError("No event data")
|
||||
username = events[0].username
|
||||
body.event_id = events[0]._id if not batch else None
|
||||
try:
|
||||
|
@ -57,7 +57,10 @@ class RangedFileResponse(Response):
|
||||
stat_result: os.stat_result | None = None,
|
||||
method: str | None = None,
|
||||
) -> None:
|
||||
assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse"
|
||||
if aiofiles is None:
|
||||
raise ModuleNotFoundError(
|
||||
"'aiofiles' must be installed to use FileResponse"
|
||||
)
|
||||
self.path = path
|
||||
self.range = range
|
||||
self.filename = filename
|
||||
|
@ -339,11 +339,11 @@ def assert_configs_are_equivalent_besides_ids(
|
||||
config2 = json.loads(json.dumps(config2))
|
||||
|
||||
for key in root_keys:
|
||||
assert config1[key] == config2[key], f"Configs have different: {key}"
|
||||
if config1[key] != config2[key]:
|
||||
raise ValueError(f"Configs have different: {key}")
|
||||
|
||||
assert len(config1["components"]) == len(
|
||||
config2["components"]
|
||||
), "# of components are different"
|
||||
if len(config1["components"]) != len(config2["components"]):
|
||||
raise ValueError("# of components are different")
|
||||
|
||||
def assert_same_components(config1_id, config2_id):
|
||||
c1 = list(filter(lambda c: c["id"] == config1_id, config1["components"]))
|
||||
@ -358,7 +358,8 @@ def assert_configs_are_equivalent_besides_ids(
|
||||
c1.pop("id")
|
||||
c2 = copy.deepcopy(c2)
|
||||
c2.pop("id")
|
||||
assert c1 == c2, f"{c1} does not match {c2}"
|
||||
if c1 != c2:
|
||||
raise ValueError(f"{c1} does not match {c2}")
|
||||
|
||||
def same_children_recursive(children1, chidren2):
|
||||
for child1, child2 in zip(children1, chidren2):
|
||||
@ -378,7 +379,8 @@ def assert_configs_are_equivalent_besides_ids(
|
||||
for o1, o2 in zip(d1.pop("outputs"), d2.pop("outputs")):
|
||||
assert_same_components(o1, o2)
|
||||
|
||||
assert d1 == d2, f"{d1} does not match {d2}"
|
||||
if d1 != d2:
|
||||
raise ValueError(f"{d1} does not match {d2}")
|
||||
|
||||
return True
|
||||
|
||||
|
@ -115,7 +115,7 @@ def test_assert_configs_are_equivalent():
|
||||
|
||||
assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config)
|
||||
assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config_diff_ids)
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(ValueError):
|
||||
assert_configs_are_equivalent_besides_ids(xray_config, xray_config_wrong)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user