Assert refactor in external.py (#5811)

* Refactored assert statements to if statements

* format-addons

* format

* add changeset

* Update gradio/external.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* refactored video.py

* refactored all the assert statements with response

* add changeset

* add changeset

* Apply suggestions from code review

* Refactored documentation.py and few more files

* avoid circular

* Replaced all assert statements

* lint

* notebooks

* fix

* minor changes

* final changes according to tests

* Lint

* last fix

* fix

* fix utils test

* fix serialization error

* fix serialization error

---------

Co-authored-by: harry-urek <hariombhardwaj038@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Hari Om Bhardwaj 2023-10-07 06:44:14 +05:30 committed by GitHub
parent 48e09ee887
commit 1d5b15a2d2
22 changed files with 198 additions and 120 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
fix:Assert refactor in external.py
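
The pattern applied across these files replaces bare assert statements, which are stripped when Python runs with the -O flag and can only raise AssertionError, with explicit checks that raise typed exceptions. A minimal sketch of the before/after shape, reusing the condition and message from the client.py hunk below:

    # Before: skipped under python -O, and callers can only catch AssertionError
    assert len(api_names) == 1, "Currently only one api_name can be deployed to discord."

    # After: always enforced, with an exception type callers can handle explicitly
    if len(api_names) != 1:
        raise ValueError("Currently only one api_name can be deployed to discord.")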

View File

@ -31,6 +31,7 @@ from packaging import version
from gradio_client import serializing, utils
from gradio_client.documentation import document, set_documentation_group
from gradio_client.exceptions import SerializationSetupError
from gradio_client.serializing import Serializable
from gradio_client.utils import (
Communicator,
@ -646,9 +647,8 @@ class Client:
raise ValueError(
f"Each entry in api_names must be either a string or a tuple of strings. Received {api_names}"
)
assert (
len(api_names) == 1
), "Currently only one api_name can be deployed to discord."
if len(api_names) != 1:
raise ValueError("Currently only one api_name can be deployed to discord.")
for i, name in enumerate(api_names):
if isinstance(name, str):
@ -676,8 +676,8 @@ class Client:
is_private = False
if self.space_id:
is_private = huggingface_hub.space_info(self.space_id).private
if is_private:
assert hf_token, (
if is_private and not hf_token:
raise ValueError(
f"Since {self.space_id} is private, you must explicitly pass in hf_token "
"so that it can be added as a secret in the discord bot space."
)
@ -777,7 +777,7 @@ class Endpoint:
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
except AssertionError:
except SerializationSetupError:
self.is_valid = False
def __repr__(self):
@ -952,9 +952,10 @@ class Endpoint:
return data
def serialize(self, *data) -> tuple:
assert len(data) == len(
self.serializers
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
if len(data) != len(self.serializers):
raise ValueError(
f"Expected {len(self.serializers)} arguments, got {len(data)}"
)
files = [
f
@ -968,9 +969,10 @@ class Endpoint:
return o
def deserialize(self, *data) -> tuple:
assert len(data) == len(
self.deserializers
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
if len(data) != len(self.deserializers):
raise ValueError(
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
)
outputs = tuple(
[
s.deserialize(
@ -1002,15 +1004,17 @@ class Endpoint:
self.input_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
assert (
serializer_name in serializing.SERIALIZER_MAPPING
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
if serializer_name not in serializing.SERIALIZER_MAPPING:
raise SerializationSetupError(
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
)
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
else:
assert (
component_name in serializing.COMPONENT_MAPPING
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
elif component_name in serializing.COMPONENT_MAPPING:
serializer = serializing.COMPONENT_MAPPING[component_name]
else:
raise SerializationSetupError(
f"Unknown component: {component_name}, you may need to update your gradio_client version."
)
serializers.append(serializer()) # type: ignore
outputs = self.dependency["outputs"]
@ -1022,17 +1026,19 @@ class Endpoint:
self.output_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
assert (
serializer_name in serializing.SERIALIZER_MAPPING
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
if serializer_name not in serializing.SERIALIZER_MAPPING:
raise SerializationSetupError(
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
)
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
elif component_name in utils.SKIP_COMPONENTS:
deserializer = serializing.SimpleSerializable
else:
assert (
component_name in serializing.COMPONENT_MAPPING
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
elif component_name in serializing.COMPONENT_MAPPING:
deserializer = serializing.COMPONENT_MAPPING[component_name]
else:
raise SerializationSetupError(
f"Unknown component: {component_name}, you may need to update your gradio_client version."
)
deserializers.append(deserializer()) # type: ignore
return serializers, deserializers

View File

@ -26,16 +26,18 @@ def extract_instance_attr_doc(cls, attr):
"self." + attr + " ="
):
break
assert i is not None, f"Could not find {attr} in {cls.__name__}"
if i is None:
raise NameError(f"Could not find {attr} in {cls.__name__}")
start_line = lines.index('"""', i)
end_line = lines.index('"""', start_line + 1)
for j in range(i + 1, start_line):
assert not lines[j].startswith("self."), (
f"Found another attribute before docstring for {attr} in {cls.__name__}: "
+ lines[j]
+ "\n start:"
+ lines[i]
)
if lines[j].startswith("self."):
raise ValueError(
f"Found another attribute before docstring for {attr} in {cls.__name__}: "
+ lines[j]
+ "\n start:"
+ lines[i]
)
doc_string = " ".join(lines[start_line + 1 : end_line])
return doc_string
@ -95,15 +97,17 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
continue
if not (line.startswith(" ") or line.strip() == ""):
print(line)
assert (
line.startswith(" ") or line.strip() == ""
), f"Documentation format for {fn.__name__} has format error in line: {line}"
if not (line.startswith(" ") or line.strip() == ""):
raise SyntaxError(
f"Documentation format for {fn.__name__} has format error in line: {line}"
)
line = line[4:]
if mode == "parameter":
colon_index = line.index(": ")
assert (
colon_index > -1
), f"Documentation format for {fn.__name__} has format error in line: {line}"
if colon_index < 0:
raise SyntaxError(
f"Documentation format for {fn.__name__} has format error in line: {line}"
)
parameter = line[:colon_index]
parameter_doc = line[colon_index + 2 :]
parameters[parameter] = parameter_doc
@ -172,9 +176,10 @@ def document_cls(cls):
if mode == "description":
description_lines.append(line if line.strip() else "<br>")
else:
assert (
line.startswith(" ") or not line.strip()
), f"Documentation format for {cls.__name__} has format error in line: {line}"
if not (line.startswith(" ") or not line.strip()):
raise SyntaxError(
f"Documentation format for {cls.__name__} has format error in line: {line}"
)
tags[mode].append(line[4:])
if "example" in tags:
example = "\n".join(tags["example"])

View File

@ -0,0 +1,4 @@
class SerializationSetupError(ValueError):
"""Raised when a serializers cannot be set up correctly."""
pass
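
Because SerializationSetupError subclasses ValueError, existing callers that already catch ValueError keep working, while Endpoint setup in client.py (see the hunk above) now catches this specific error instead of AssertionError. A rough sketch of the intended flow, assuming an Endpoint instance named endpoint:

    from gradio_client.exceptions import SerializationSetupError

    try:
        # Raised when the Space config references an unknown component or serializer
        endpoint.serializers, endpoint.deserializers = endpoint._setup_serializers()
    except SerializationSetupError:
        # Mark just this endpoint invalid instead of failing the whole Client
        endpoint.is_valid = False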

View File

@ -307,7 +307,8 @@ class FileSerializable(Serializable):
elif isinstance(x, dict):
if x.get("is_file"):
filepath = x.get("name")
assert filepath is not None, f"The 'name' field is missing in {x}"
if filepath is None:
raise ValueError(f"The 'name' field is missing in {x}")
if root_url is not None:
file_name = utils.download_tmp_copy_of_file(
root_url + "file=" + filepath,
@ -331,7 +332,8 @@ class FileSerializable(Serializable):
file_name = str(path)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"
if data is None:
raise ValueError(f"The 'data' field is missing in {x}")
file_name = utils.decode_base64_to_file(data, dir=save_dir).name
else:
raise ValueError(
@ -426,7 +428,8 @@ class VideoSerializable(FileSerializable):
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
"""
if isinstance(x, (tuple, list)):
assert len(x) == 2, f"Expected tuple of length 2. Received: {x}"
if len(x) != 2:
raise ValueError(f"Expected tuple of length 2. Received: {x}")
x_as_list = [x[0], x[1]]
else:
raise ValueError(f"Expected tuple of length 2. Received: {x}")

File diff suppressed because one or more lines are too long

View File

@ -20,7 +20,8 @@ FIGSIZE = 7, 7 # does not affect size in webpage
COLORS = [
'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
]
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
if len(COLORS) < MAX_CLUSTERS:
raise ValueError("Not enough different colors for all clusters")
np.random.seed(SEED)

View File

@ -44,6 +44,7 @@ from gradio.exceptions import (
DuplicateBlockError,
InvalidApiNameError,
InvalidBlockError,
InvalidComponentError,
)
from gradio.helpers import EventData, create_tracker, skip, special_args
from gradio.state_holder import SessionState
@ -364,9 +365,10 @@ def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool =
attr_dict["__type__"] = "update"
attr_dict.pop("value", None)
if "value" in update_dict:
assert isinstance(
block, components.IOComponent
), f"Component {block.__class__} does not support value"
if not isinstance(block, components.IOComponent):
raise InvalidComponentError(
f"Component {block.__class__} does not support value"
)
if postprocess:
attr_dict["value"] = block.postprocess(update_dict["value"])
else:
@ -766,9 +768,10 @@ class Blocks(BlockContext):
children = child_config.get("children")
if children is not None:
assert isinstance(
block, BlockContext
), f"Invalid config, Block with id {id} has children but is not a BlockContext."
if not isinstance(block, BlockContext):
raise ValueError(
f"Invalid config, Block with id {id} has children but is not a BlockContext."
)
with block:
iterate_over_children(children)
@ -1158,7 +1161,8 @@ class Blocks(BlockContext):
event_data: data associated with event trigger
"""
block_fn = self.fns[fn_index]
assert block_fn.fn, f"function with index {fn_index} not defined."
if not block_fn.fn:
raise IndexError(f"function with index {fn_index} not defined.")
is_generating = False
request = requests[0] if isinstance(requests, list) else requests
start = time.time()
@ -1234,9 +1238,10 @@ class Blocks(BlockContext):
raise InvalidBlockError(
f"Input component with id {input_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
assert isinstance(
block, components.IOComponent
), f"{block.__class__} Component with id {input_id} not a valid input component."
if not isinstance(block, components.IOComponent):
raise InvalidComponentError(
f"{block.__class__} Component with id {input_id} not a valid input component."
)
serialized_input = block.serialize(inputs[i])
processed_input.append(serialized_input)
@ -1253,9 +1258,10 @@ class Blocks(BlockContext):
raise InvalidBlockError(
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
assert isinstance(
block, components.IOComponent
), f"{block.__class__} Component with id {output_id} not a valid output component."
if not isinstance(block, components.IOComponent):
raise InvalidComponentError(
f"{block.__class__} Component with id {output_id} not a valid output component."
)
deserialized = block.deserialize(
outputs[o],
save_dir=block.DEFAULT_TEMP_DIR,
@ -1322,9 +1328,10 @@ Received inputs:
raise InvalidBlockError(
f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {input_id} not a valid input component."
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {input_id} not a valid input component."
)
if getattr(block, "stateful", False):
processed_input.append(state[input_id])
else:
@ -1445,9 +1452,10 @@ Received outputs:
postprocess=block_fn.postprocess,
)
elif block_fn.postprocess:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {output_id} not a valid output component."
if not isinstance(block, components.Component):
raise InvalidComponentError(
f"{block.__class__} Component with id {output_id} not a valid output component."
)
prediction_value = block.postprocess(prediction_value)
output.append(prediction_value)
@ -2005,9 +2013,8 @@ Received outputs:
)
if self.is_running:
assert isinstance(
self.local_url, str
), f"Invalid local_url: {self.local_url}"
if not isinstance(self.local_url, str):
raise ValueError(f"Invalid local_url: {self.local_url}")
if not (quiet):
print(
"Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"

View File

@ -205,12 +205,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable):
return y
processed_messages = []
for message_pair in y:
assert isinstance(
message_pair, (tuple, list)
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
assert (
len(message_pair) == 2
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
if not isinstance(message_pair, (tuple, list)):
raise TypeError(
f"Expected a list of lists or list of tuples. Received: {message_pair}"
)
if len(message_pair) != 2:
raise TypeError(
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
)
processed_messages.append(
[
self._preprocess_chat_messages(message_pair[0]),
@ -259,12 +261,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable):
return []
processed_messages = []
for message_pair in y:
assert isinstance(
message_pair, (tuple, list)
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
assert (
len(message_pair) == 2
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
if not isinstance(message_pair, (tuple, list)):
raise TypeError(
f"Expected a list of lists or list of tuples. Received: {message_pair}"
)
if len(message_pair) != 2:
raise TypeError(
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
)
processed_messages.append(
[
self._postprocess_chat_messages(message_pair[0]),

View File

@ -81,7 +81,9 @@ class Code(Changeable, Inputable, IOComponent, StringSerializable):
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
assert language in Code.languages, f"Language {language} not supported."
if language not in Code.languages:
raise ValueError(f"Language {language} not supported.")
self.language = language
self.lines = lines
IOComponent.__init__(

View File

@ -280,7 +280,8 @@ class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable
return self.postprocess([[]])
if isinstance(y, np.ndarray):
y = y.tolist()
assert isinstance(y, list), "output cannot be converted to list"
if not isinstance(y, list):
raise ValueError("output cannot be converted to list")
_headers = self.headers
if len(self.headers) < len(y[0]):

View File

@ -68,9 +68,10 @@ class Dataset(Clickable, Selectable, Component, StringSerializable):
self._components = [get_component_instance(c) for c in components]
# Narrow type to IOComponent
assert all(
isinstance(c, IOComponent) for c in self._components
), "All components in a `Dataset` must be subclasses of `IOComponent`"
if not all(isinstance(c, IOComponent) for c in self._components):
raise ValueError(
"All components in a `Dataset` must be subclasses of `IOComponent`"
)
self._components = [c for c in self._components if isinstance(c, IOComponent)]
for component in self._components:
component.root_url = self.root_url

View File

@ -193,14 +193,16 @@ class Video(
)
if is_file:
assert file_name is not None, "Received file data without a file name."
if file_name is None:
raise ValueError("Received file data without a file name.")
if client_utils.is_http_url_like(file_name):
fn = self.download_temp_copy_if_needed
else:
fn = self.make_temp_copy_if_needed
file_name = Path(fn(file_name))
else:
assert file_data is not None, "Received empty file data."
if file_data is None:
raise ValueError("Received empty file data.")
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))
uploaded_format = file_name.suffix.replace(".", "")
@ -270,12 +272,15 @@ class Video(
if isinstance(y, (str, Path)):
processed_files = (self._format_video(y), None)
elif isinstance(y, (tuple, list)):
assert (
len(y) == 2
), f"Expected lists of length 2 or tuples of length 2. Received: {y}"
assert isinstance(y[0], (str, Path)) and isinstance(
y[1], (str, Path)
), f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}"
if len(y) != 2:
raise ValueError(
f"Expected lists of length 2 or tuples of length 2. Received: {y}"
)
if not (isinstance(y[0], (str, Path)) and isinstance(y[1], (str, Path))):
raise TypeError(
f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}"
)
video = y[0]
subtitle = y[1]
processed_files = (

View File

@ -9,12 +9,30 @@ class DuplicateBlockError(ValueError):
pass
class InvalidComponentError(ValueError):
"""Raised when invalid components are used."""
pass
class TooManyRequestsError(Exception):
"""Raised when the Hugging Face API returns a 429 status code."""
pass
class ModelNotFoundError(Exception):
"""Raised when the provided model doesn't exists or is not found by the provided api url."""
pass
class RenderError(Exception):
"""Raised when a component has not been rendered in the current Blocks but is expected to have been rendered."""
pass
class InvalidApiNameError(ValueError):
pass
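
With ModelNotFoundError in place, a failed model lookup in gr.load surfaces as a dedicated exception rather than an AssertionError. A rough usage sketch (the repo id here is hypothetical):

    import gradio as gr
    from gradio.exceptions import ModelNotFoundError

    try:
        demo = gr.load("models/some-org/nonexistent-model")  # hypothetical repo id
    except ModelNotFoundError:
        # Private or gated models need a token from https://huggingface.co/settings/tokens
        print("Model not found; pass hf_token=... if it is private or gated.")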

View File

@ -16,7 +16,7 @@ import gradio
from gradio import components, utils
from gradio.context import Context
from gradio.deprecation import warn_deprecation
from gradio.exceptions import Error, TooManyRequestsError
from gradio.exceptions import Error, ModelNotFoundError, TooManyRequestsError
from gradio.external_utils import (
cols_to_rows,
encode_to_base64,
@ -83,9 +83,10 @@ def load_blocks_from_repo(
if src is None:
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
tokens = name.split("/")
assert (
len(tokens) > 1
), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
if len(tokens) <= 1:
raise ValueError(
"Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
)
src = tokens[0]
name = "/".join(tokens[1:])
@ -95,9 +96,8 @@ def load_blocks_from_repo(
"models": from_model,
"spaces": from_spaces,
}
assert (
src.lower() in factory_methods
), f"parameter: src must be one of {factory_methods.keys()}"
if src.lower() not in factory_methods:
raise ValueError(f"parameter: src must be one of {factory_methods.keys()}")
if hf_token is not None:
if Context.hf_token is not None and Context.hf_token != hf_token:
@ -145,9 +145,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
# Checking if model exists, and if so, it gets the pipeline
response = requests.request("GET", api_url, headers=headers)
assert (
response.status_code == 200
), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
if response.status_code != 200:
raise ModelNotFoundError(
f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
)
p = response.json().get("pipeline_tag")
pipelines = {
"audio-classification": {

View File

@ -525,7 +525,8 @@ class Progress(Iterable):
):
current_iterable = self.iterables.pop()
callback(self.iterables)
assert current_iterable.index is not None, "Index not set."
if current_iterable.index is None:
raise IndexError("Index not set.")
current_iterable.index += 1
try:
return next(current_iterable.iterable) # type: ignore
@ -603,7 +604,8 @@ class Progress(Iterable):
callback = self._progress_callback()
if callback and len(self.iterables) > 0:
current_iterable = self.iterables[-1]
assert current_iterable.index is not None, "Index not set."
if current_iterable.index is None:
raise IndexError("Index not set.")
current_iterable.index += n
callback(self.iterables)
else:

View File

@ -28,6 +28,7 @@ from gradio.components import (
from gradio.data_classes import InterfaceTypes
from gradio.deprecation import warn_deprecation
from gradio.events import Changeable, Streamable, Submittable, on
from gradio.exceptions import RenderError
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
@ -449,7 +450,8 @@ class Interface(Blocks):
stop_btn = stop_btn or stop_btn_2_out
flag_btns = flag_btns or flag_btns_out
assert clear_btn is not None, "Clear button not rendered"
if clear_btn is None:
raise RenderError("Clear button not rendered")
self.attach_submit_events(submit_btn, stop_btn)
self.attach_clear_events(
@ -586,7 +588,8 @@ class Interface(Blocks):
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
assert submit_btn is not None, "Submit button not rendered"
if submit_btn is None:
raise RenderError("Submit button not rendered")
flag_btns = [submit_btn]
if self.interpretation:
@ -611,7 +614,8 @@ class Interface(Blocks):
def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
if self.live:
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
assert submit_btn is not None, "Submit button not rendered"
if submit_btn is None:
raise RenderError("Submit button not rendered")
super().load(self.fn, None, self.output_components)
# For output-only interfaces, the user probably still want a "generate"
# button even if the Interface is live
@ -642,7 +646,8 @@ class Interface(Blocks):
postprocess=not (self.api_mode),
)
else:
assert submit_btn is not None, "Submit button not rendered"
if submit_btn is None:
raise RenderError("Submit button not rendered")
fn = self.fn
extra_output = []

View File

@ -230,7 +230,8 @@ async def run_interpret(interface: Interface, raw_input: list):
nsamples=int(interface.num_shap * num_total_segments),
silent=True,
)
assert shap_values is not None, "SHAP values could not be calculated"
if shap_values is None:
raise ValueError("SHAP values could not be calculated")
scores.append(
input_component.get_interpretation_scores(
raw_input[i],

View File

@ -368,7 +368,8 @@ class Queue:
async def call_prediction(self, events: list[Event], batch: bool):
body = events[0].data
assert body is not None, "No event data"
if body is None:
raise ValueError("No event data")
username = events[0].username
body.event_id = events[0]._id if not batch else None
try:

View File

@ -57,7 +57,10 @@ class RangedFileResponse(Response):
stat_result: os.stat_result | None = None,
method: str | None = None,
) -> None:
assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse"
if aiofiles is None:
raise ModuleNotFoundError(
"'aiofiles' must be installed to use FileResponse"
)
self.path = path
self.range = range
self.filename = filename

View File

@ -339,11 +339,11 @@ def assert_configs_are_equivalent_besides_ids(
config2 = json.loads(json.dumps(config2))
for key in root_keys:
assert config1[key] == config2[key], f"Configs have different: {key}"
if config1[key] != config2[key]:
raise ValueError(f"Configs have different: {key}")
assert len(config1["components"]) == len(
config2["components"]
), "# of components are different"
if len(config1["components"]) != len(config2["components"]):
raise ValueError("# of components are different")
def assert_same_components(config1_id, config2_id):
c1 = list(filter(lambda c: c["id"] == config1_id, config1["components"]))
@ -358,7 +358,8 @@ def assert_configs_are_equivalent_besides_ids(
c1.pop("id")
c2 = copy.deepcopy(c2)
c2.pop("id")
assert c1 == c2, f"{c1} does not match {c2}"
if c1 != c2:
raise ValueError(f"{c1} does not match {c2}")
def same_children_recursive(children1, chidren2):
for child1, child2 in zip(children1, chidren2):
@ -378,7 +379,8 @@ def assert_configs_are_equivalent_besides_ids(
for o1, o2 in zip(d1.pop("outputs"), d2.pop("outputs")):
assert_same_components(o1, o2)
assert d1 == d2, f"{d1} does not match {d2}"
if d1 != d2:
raise ValueError(f"{d1} does not match {d2}")
return True

View File

@ -115,7 +115,7 @@ def test_assert_configs_are_equivalent():
assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config)
assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config_diff_ids)
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
assert_configs_are_equivalent_besides_ids(xray_config, xray_config_wrong)