diff --git a/.changeset/light-buses-enter.md b/.changeset/light-buses-enter.md new file mode 100644 index 0000000000..4f40a78578 --- /dev/null +++ b/.changeset/light-buses-enter.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +fix:Assert refactor in external.py diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index bccb303fb4..f6a734512f 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -31,6 +31,7 @@ from packaging import version from gradio_client import serializing, utils from gradio_client.documentation import document, set_documentation_group +from gradio_client.exceptions import SerializationSetupError from gradio_client.serializing import Serializable from gradio_client.utils import ( Communicator, @@ -646,9 +647,8 @@ class Client: raise ValueError( f"Each entry in api_names must be either a string or a tuple of strings. Received {api_names}" ) - assert ( - len(api_names) == 1 - ), "Currently only one api_name can be deployed to discord." + if len(api_names) != 1: + raise ValueError("Currently only one api_name can be deployed to discord.") for i, name in enumerate(api_names): if isinstance(name, str): @@ -676,8 +676,8 @@ class Client: is_private = False if self.space_id: is_private = huggingface_hub.space_info(self.space_id).private - if is_private: - assert hf_token, ( + if is_private and not hf_token: + raise ValueError( f"Since {self.space_id} is private, you must explicitly pass in hf_token " "so that it can be added as a secret in the discord bot space." ) @@ -777,7 +777,7 @@ class Endpoint: # and api_name is not False (meaning that the developer has explicitly disabled the API endpoint) self.serializers, self.deserializers = self._setup_serializers() self.is_valid = self.dependency["backend_fn"] and self.api_name is not False - except AssertionError: + except SerializationSetupError: self.is_valid = False def __repr__(self): @@ -952,9 +952,10 @@ class Endpoint: return data def serialize(self, *data) -> tuple: - assert len(data) == len( - self.serializers - ), f"Expected {len(self.serializers)} arguments, got {len(data)}" + if len(data) != len(self.serializers): + raise ValueError( + f"Expected {len(self.serializers)} arguments, got {len(data)}" + ) files = [ f @@ -968,9 +969,10 @@ class Endpoint: return o def deserialize(self, *data) -> tuple: - assert len(data) == len( - self.deserializers - ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" + if len(data) != len(self.deserializers): + raise ValueError( + f"Expected {len(self.deserializers)} outputs, got {len(data)}" + ) outputs = tuple( [ s.deserialize( @@ -1002,15 +1004,17 @@ class Endpoint: self.input_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] - assert ( - serializer_name in serializing.SERIALIZER_MAPPING - ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." + if serializer_name not in serializing.SERIALIZER_MAPPING: + raise SerializationSetupError( + f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." + ) serializer = serializing.SERIALIZER_MAPPING[serializer_name] - else: - assert ( - component_name in serializing.COMPONENT_MAPPING - ), f"Unknown component: {component_name}, you may need to update your gradio_client version." + elif component_name in serializing.COMPONENT_MAPPING: serializer = serializing.COMPONENT_MAPPING[component_name] + else: + raise SerializationSetupError( + f"Unknown component: {component_name}, you may need to update your gradio_client version." + ) serializers.append(serializer()) # type: ignore outputs = self.dependency["outputs"] @@ -1022,17 +1026,19 @@ class Endpoint: self.output_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] - assert ( - serializer_name in serializing.SERIALIZER_MAPPING - ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." + if serializer_name not in serializing.SERIALIZER_MAPPING: + raise SerializationSetupError( + f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." + ) deserializer = serializing.SERIALIZER_MAPPING[serializer_name] elif component_name in utils.SKIP_COMPONENTS: deserializer = serializing.SimpleSerializable - else: - assert ( - component_name in serializing.COMPONENT_MAPPING - ), f"Unknown component: {component_name}, you may need to update your gradio_client version." + elif component_name in serializing.COMPONENT_MAPPING: deserializer = serializing.COMPONENT_MAPPING[component_name] + else: + raise SerializationSetupError( + f"Unknown component: {component_name}, you may need to update your gradio_client version." + ) deserializers.append(deserializer()) # type: ignore return serializers, deserializers diff --git a/client/python/gradio_client/documentation.py b/client/python/gradio_client/documentation.py index 4d8d41ddf8..53b0643688 100644 --- a/client/python/gradio_client/documentation.py +++ b/client/python/gradio_client/documentation.py @@ -26,16 +26,18 @@ def extract_instance_attr_doc(cls, attr): "self." + attr + " =" ): break - assert i is not None, f"Could not find {attr} in {cls.__name__}" + if i is None: + raise NameError(f"Could not find {attr} in {cls.__name__}") start_line = lines.index('"""', i) end_line = lines.index('"""', start_line + 1) for j in range(i + 1, start_line): - assert not lines[j].startswith("self."), ( - f"Found another attribute before docstring for {attr} in {cls.__name__}: " - + lines[j] - + "\n start:" - + lines[i] - ) + if lines[j].startswith("self."): + raise ValueError( + f"Found another attribute before docstring for {attr} in {cls.__name__}: " + + lines[j] + + "\n start:" + + lines[i] + ) doc_string = " ".join(lines[start_line + 1 : end_line]) return doc_string @@ -95,15 +97,17 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]: continue if not (line.startswith(" ") or line.strip() == ""): print(line) - assert ( - line.startswith(" ") or line.strip() == "" - ), f"Documentation format for {fn.__name__} has format error in line: {line}" + if not (line.startswith(" ") or line.strip() == ""): + raise SyntaxError( + f"Documentation format for {fn.__name__} has format error in line: {line}" + ) line = line[4:] if mode == "parameter": colon_index = line.index(": ") - assert ( - colon_index > -1 - ), f"Documentation format for {fn.__name__} has format error in line: {line}" + if colon_index < -1: + raise SyntaxError( + f"Documentation format for {fn.__name__} has format error in line: {line}" + ) parameter = line[:colon_index] parameter_doc = line[colon_index + 2 :] parameters[parameter] = parameter_doc @@ -172,9 +176,10 @@ def document_cls(cls): if mode == "description": description_lines.append(line if line.strip() else "
") else: - assert ( - line.startswith(" ") or not line.strip() - ), f"Documentation format for {cls.__name__} has format error in line: {line}" + if not (line.startswith(" ") or not line.strip()): + raise SyntaxError( + f"Documentation format for {cls.__name__} has format error in line: {line}" + ) tags[mode].append(line[4:]) if "example" in tags: example = "\n".join(tags["example"]) diff --git a/client/python/gradio_client/exceptions.py b/client/python/gradio_client/exceptions.py new file mode 100644 index 0000000000..5329124e7a --- /dev/null +++ b/client/python/gradio_client/exceptions.py @@ -0,0 +1,4 @@ +class SerializationSetupError(ValueError): + """Raised when a serializers cannot be set up correctly.""" + + pass diff --git a/client/python/gradio_client/serializing.py b/client/python/gradio_client/serializing.py index b6548f8fc3..ca0831de8d 100644 --- a/client/python/gradio_client/serializing.py +++ b/client/python/gradio_client/serializing.py @@ -307,7 +307,8 @@ class FileSerializable(Serializable): elif isinstance(x, dict): if x.get("is_file"): filepath = x.get("name") - assert filepath is not None, f"The 'name' field is missing in {x}" + if filepath is None: + raise ValueError(f"The 'name' field is missing in {x}") if root_url is not None: file_name = utils.download_tmp_copy_of_file( root_url + "file=" + filepath, @@ -331,7 +332,8 @@ class FileSerializable(Serializable): file_name = str(path) else: data = x.get("data") - assert data is not None, f"The 'data' field is missing in {x}" + if data is None: + raise ValueError(f"The 'data' field is missing in {x}") file_name = utils.decode_base64_to_file(data, dir=save_dir).name else: raise ValueError( @@ -426,7 +428,8 @@ class VideoSerializable(FileSerializable): version (string filepath). Optionally, save the file to the directory specified by `save_dir` """ if isinstance(x, (tuple, list)): - assert len(x) == 2, f"Expected tuple of length 2. Received: {x}" + if len(x) != 2: + raise ValueError(f"Expected tuple of length 2. Received: {x}") x_as_list = [x[0], x[1]] else: raise ValueError(f"Expected tuple of length 2. Received: {x}") diff --git a/demo/clustering/run.ipynb b/demo/clustering/run.ipynb index 08fb7d0b82..0946101f93 100644 --- a/demo/clustering/run.ipynb +++ b/demo/clustering/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: clustering\n", "### This demo built with Blocks generates 9 plots based on the input.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio matplotlib>=3.5.2 scikit-learn>=1.0.1 "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import math\n", "from functools import partial\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn.cluster import (\n", " AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth\n", ")\n", "from sklearn.datasets import make_blobs, make_circles, make_moons\n", "from sklearn.mixture import GaussianMixture\n", "from sklearn.neighbors import kneighbors_graph\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "plt.style.use('seaborn-v0_8')\n", "SEED = 0\n", "MAX_CLUSTERS = 10\n", "N_SAMPLES = 1000\n", "N_COLS = 3\n", "FIGSIZE = 7, 7 # does not affect size in webpage\n", "COLORS = [\n", " 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'\n", "]\n", "assert len(COLORS) >= MAX_CLUSTERS, \"Not enough different colors for all clusters\"\n", "np.random.seed(SEED)\n", "\n", "\n", "def normalize(X):\n", " return StandardScaler().fit_transform(X)\n", "\n", "def get_regular(n_clusters):\n", " # spiral pattern\n", " centers = [\n", " [0, 0],\n", " [1, 0],\n", " [1, 1],\n", " [0, 1],\n", " [-1, 1],\n", " [-1, 0],\n", " [-1, -1],\n", " [0, -1],\n", " [1, -1],\n", " [2, -1],\n", " ][:n_clusters]\n", " assert len(centers) == n_clusters\n", " X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)\n", " return normalize(X), labels\n", "\n", "\n", "def get_circles(n_clusters):\n", " X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)\n", " return normalize(X), labels\n", "\n", "\n", "def get_moons(n_clusters):\n", " X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)\n", " return normalize(X), labels\n", "\n", "\n", "def get_noise(n_clusters):\n", " np.random.seed(SEED)\n", " X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(0, n_clusters, size=(N_SAMPLES,))\n", " return normalize(X), labels\n", "\n", "\n", "def get_anisotropic(n_clusters):\n", " X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)\n", " transformation = [[0.6, -0.6], [-0.4, 0.8]]\n", " X = np.dot(X, transformation)\n", " return X, labels\n", "\n", "\n", "def get_varied(n_clusters):\n", " cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]\n", " assert len(cluster_std) == n_clusters\n", " X, labels = make_blobs(\n", " n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED\n", " )\n", " return normalize(X), labels\n", "\n", "\n", "def get_spiral(n_clusters):\n", " # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html\n", " np.random.seed(SEED)\n", " t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))\n", " x = t * np.cos(t)\n", " y = t * np.sin(t)\n", " X = np.concatenate((x, y))\n", " X += 0.7 * np.random.randn(2, N_SAMPLES)\n", " X = np.ascontiguousarray(X.T)\n", "\n", " labels = np.zeros(N_SAMPLES, dtype=int)\n", " return normalize(X), labels\n", "\n", "\n", "DATA_MAPPING = {\n", " 'regular': get_regular,\n", " 'circles': get_circles,\n", " 'moons': get_moons,\n", " 'spiral': get_spiral,\n", " 'noise': get_noise,\n", " 'anisotropic': get_anisotropic,\n", " 'varied': get_varied,\n", "}\n", "\n", "\n", "def get_groundtruth_model(X, labels, n_clusters, **kwargs):\n", " # dummy model to show true label distribution\n", " class Dummy:\n", " def __init__(self, y):\n", " self.labels_ = labels\n", "\n", " return Dummy(labels)\n", "\n", "\n", "def get_kmeans(X, labels, n_clusters, **kwargs):\n", " model = KMeans(init=\"k-means++\", n_clusters=n_clusters, n_init=10, random_state=SEED)\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_dbscan(X, labels, n_clusters, **kwargs):\n", " model = DBSCAN(eps=0.3)\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_agglomerative(X, labels, n_clusters, **kwargs):\n", " connectivity = kneighbors_graph(\n", " X, n_neighbors=n_clusters, include_self=False\n", " )\n", " # make connectivity symmetric\n", " connectivity = 0.5 * (connectivity + connectivity.T)\n", " model = AgglomerativeClustering(\n", " n_clusters=n_clusters, linkage=\"ward\", connectivity=connectivity\n", " )\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_meanshift(X, labels, n_clusters, **kwargs):\n", " bandwidth = estimate_bandwidth(X, quantile=0.25)\n", " model = MeanShift(bandwidth=bandwidth, bin_seeding=True)\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_spectral(X, labels, n_clusters, **kwargs):\n", " model = SpectralClustering(\n", " n_clusters=n_clusters,\n", " eigen_solver=\"arpack\",\n", " affinity=\"nearest_neighbors\",\n", " )\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_optics(X, labels, n_clusters, **kwargs):\n", " model = OPTICS(\n", " min_samples=7,\n", " xi=0.05,\n", " min_cluster_size=0.1,\n", " )\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_birch(X, labels, n_clusters, **kwargs):\n", " model = Birch(n_clusters=n_clusters)\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "def get_gaussianmixture(X, labels, n_clusters, **kwargs):\n", " model = GaussianMixture(\n", " n_components=n_clusters, covariance_type=\"full\", random_state=SEED,\n", " )\n", " model.set_params(**kwargs)\n", " return model.fit(X)\n", "\n", "\n", "MODEL_MAPPING = {\n", " 'True labels': get_groundtruth_model,\n", " 'KMeans': get_kmeans,\n", " 'DBSCAN': get_dbscan,\n", " 'MeanShift': get_meanshift,\n", " 'SpectralClustering': get_spectral,\n", " 'OPTICS': get_optics,\n", " 'Birch': get_birch,\n", " 'GaussianMixture': get_gaussianmixture,\n", " 'AgglomerativeClustering': get_agglomerative,\n", "}\n", "\n", "\n", "def plot_clusters(ax, X, labels):\n", " set_clusters = set(labels)\n", " set_clusters.discard(-1) # -1 signifiies outliers, which we plot separately\n", " for label, color in zip(sorted(set_clusters), COLORS):\n", " idx = labels == label\n", " if not sum(idx):\n", " continue\n", " ax.scatter(X[idx, 0], X[idx, 1], color=color)\n", "\n", " # show outliers (if any)\n", " idx = labels == -1\n", " if sum(idx):\n", " ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')\n", "\n", " ax.grid(None)\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " return ax\n", "\n", "\n", "def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):\n", " if isinstance(n_clusters, dict):\n", " n_clusters = n_clusters['value']\n", " else:\n", " n_clusters = int(n_clusters)\n", "\n", " X, labels = DATA_MAPPING[dataset](n_clusters)\n", " model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)\n", " if hasattr(model, \"labels_\"):\n", " y_pred = model.labels_.astype(int)\n", " else:\n", " y_pred = model.predict(X)\n", "\n", " fig, ax = plt.subplots(figsize=FIGSIZE)\n", "\n", " plot_clusters(ax, X, y_pred)\n", " ax.set_title(clustering_algorithm, fontsize=16)\n", "\n", " return fig\n", "\n", "\n", "title = \"Clustering with Scikit-learn\"\n", "description = (\n", " \"This example shows how different clustering algorithms work. Simply pick \"\n", " \"the dataset and the number of clusters to see how the clustering algorithms work. \"\n", " \"Colored circles are (predicted) labels and black x are outliers.\"\n", ")\n", "\n", "\n", "def iter_grid(n_rows, n_cols):\n", " # create a grid using gradio Block\n", " for _ in range(n_rows):\n", " with gr.Row():\n", " for _ in range(n_cols):\n", " with gr.Column():\n", " yield\n", "\n", "with gr.Blocks(title=title) as demo:\n", " gr.HTML(f\"{title}\")\n", " gr.Markdown(description)\n", "\n", " input_models = list(MODEL_MAPPING)\n", " input_data = gr.Radio(\n", " list(DATA_MAPPING),\n", " value=\"regular\",\n", " label=\"dataset\"\n", " )\n", " input_n_clusters = gr.Slider(\n", " minimum=1,\n", " maximum=MAX_CLUSTERS,\n", " value=4,\n", " step=1,\n", " label='Number of clusters'\n", " )\n", " n_rows = int(math.ceil(len(input_models) / N_COLS))\n", " counter = 0\n", " for _ in iter_grid(n_rows, N_COLS):\n", " if counter >= len(input_models):\n", " break\n", "\n", " input_model = input_models[counter]\n", " plot = gr.Plot(label=input_model)\n", " fn = partial(cluster, clustering_algorithm=input_model)\n", " input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n", " input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n", " counter += 1\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Gradio Demo: clustering\n","### This demo built with Blocks generates 9 plots based on the input.\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q gradio matplotlib>=3.5.2 scikit-learn>=1.0.1 "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import gradio as gr\n","import math\n","from functools import partial\n","import matplotlib.pyplot as plt\n","import numpy as np\n","from sklearn.cluster import (\n"," AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth\n",")\n","from sklearn.datasets import make_blobs, make_circles, make_moons\n","from sklearn.mixture import GaussianMixture\n","from sklearn.neighbors import kneighbors_graph\n","from sklearn.preprocessing import StandardScaler\n","\n","plt.style.use('seaborn-v0_8')\n","SEED = 0\n","MAX_CLUSTERS = 10\n","N_SAMPLES = 1000\n","N_COLS = 3\n","FIGSIZE = 7, 7 # does not affect size in webpage\n","COLORS = [\n"," 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'\n","]\n","if len(COLORS) <= MAX_CLUSTERS:\n"," raise ValueError(\"Not enough different colors for all clusters\")\n","np.random.seed(SEED)\n","\n","\n","def normalize(X):\n"," return StandardScaler().fit_transform(X)\n","\n","\n","def get_regular(n_clusters):\n"," # spiral pattern\n"," centers = [\n"," [0, 0],\n"," [1, 0],\n"," [1, 1],\n"," [0, 1],\n"," [-1, 1],\n"," [-1, 0],\n"," [-1, -1],\n"," [0, -1],\n"," [1, -1],\n"," [2, -1],\n"," ][:n_clusters]\n"," assert len(centers) == n_clusters\n"," X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers,\n"," cluster_std=0.25, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_circles(n_clusters):\n"," X, labels = make_circles(\n"," n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_moons(n_clusters):\n"," X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_noise(n_clusters):\n"," np.random.seed(SEED)\n"," X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(\n"," 0, n_clusters, size=(N_SAMPLES,))\n"," return normalize(X), labels\n","\n","\n","def get_anisotropic(n_clusters):\n"," X, labels = make_blobs(n_samples=N_SAMPLES,\n"," centers=n_clusters, random_state=170)\n"," transformation = [[0.6, -0.6], [-0.4, 0.8]]\n"," X = np.dot(X, transformation)\n"," return X, labels\n","\n","\n","def get_varied(n_clusters):\n"," cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5,\n"," 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]\n"," assert len(cluster_std) == n_clusters\n"," X, labels = make_blobs(\n"," n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED\n"," )\n"," return normalize(X), labels\n","\n","\n","def get_spiral(n_clusters):\n"," # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html\n"," np.random.seed(SEED)\n"," t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))\n"," x = t * np.cos(t)\n"," y = t * np.sin(t)\n"," X = np.concatenate((x, y))\n"," X += 0.7 * np.random.randn(2, N_SAMPLES)\n"," X = np.ascontiguousarray(X.T)\n","\n"," labels = np.zeros(N_SAMPLES, dtype=int)\n"," return normalize(X), labels\n","\n","\n","DATA_MAPPING = {\n"," 'regular': get_regular,\n"," 'circles': get_circles,\n"," 'moons': get_moons,\n"," 'spiral': get_spiral,\n"," 'noise': get_noise,\n"," 'anisotropic': get_anisotropic,\n"," 'varied': get_varied,\n","}\n","\n","\n","def get_groundtruth_model(X, labels, n_clusters, **kwargs):\n"," # dummy model to show true label distribution\n"," class Dummy:\n"," def __init__(self, y):\n"," self.labels_ = labels\n","\n"," return Dummy(labels)\n","\n","\n","def get_kmeans(X, labels, n_clusters, **kwargs):\n"," model = KMeans(init=\"k-means++\", n_clusters=n_clusters,\n"," n_init=10, random_state=SEED)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_dbscan(X, labels, n_clusters, **kwargs):\n"," model = DBSCAN(eps=0.3)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_agglomerative(X, labels, n_clusters, **kwargs):\n"," connectivity = kneighbors_graph(\n"," X, n_neighbors=n_clusters, include_self=False\n"," )\n"," # make connectivity symmetric\n"," connectivity = 0.5 * (connectivity + connectivity.T)\n"," model = AgglomerativeClustering(\n"," n_clusters=n_clusters, linkage=\"ward\", connectivity=connectivity\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_meanshift(X, labels, n_clusters, **kwargs):\n"," bandwidth = estimate_bandwidth(X, quantile=0.25)\n"," model = MeanShift(bandwidth=bandwidth, bin_seeding=True)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_spectral(X, labels, n_clusters, **kwargs):\n"," model = SpectralClustering(\n"," n_clusters=n_clusters,\n"," eigen_solver=\"arpack\",\n"," affinity=\"nearest_neighbors\",\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_optics(X, labels, n_clusters, **kwargs):\n"," model = OPTICS(\n"," min_samples=7,\n"," xi=0.05,\n"," min_cluster_size=0.1,\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_birch(X, labels, n_clusters, **kwargs):\n"," model = Birch(n_clusters=n_clusters)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_gaussianmixture(X, labels, n_clusters, **kwargs):\n"," model = GaussianMixture(\n"," n_components=n_clusters, covariance_type=\"full\", random_state=SEED,\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","MODEL_MAPPING = {\n"," 'True labels': get_groundtruth_model,\n"," 'KMeans': get_kmeans,\n"," 'DBSCAN': get_dbscan,\n"," 'MeanShift': get_meanshift,\n"," 'SpectralClustering': get_spectral,\n"," 'OPTICS': get_optics,\n"," 'Birch': get_birch,\n"," 'GaussianMixture': get_gaussianmixture,\n"," 'AgglomerativeClustering': get_agglomerative,\n","}\n","\n","\n","def plot_clusters(ax, X, labels):\n"," set_clusters = set(labels)\n"," # -1 signifiies outliers, which we plot separately\n"," set_clusters.discard(-1)\n"," for label, color in zip(sorted(set_clusters), COLORS):\n"," idx = labels == label\n"," if not sum(idx):\n"," continue\n"," ax.scatter(X[idx, 0], X[idx, 1], color=color)\n","\n"," # show outliers (if any)\n"," idx = labels == -1\n"," if sum(idx):\n"," ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')\n","\n"," ax.grid(None)\n"," ax.set_xticks([])\n"," ax.set_yticks([])\n"," return ax\n","\n","\n","def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):\n"," if isinstance(n_clusters, dict):\n"," n_clusters = n_clusters['value']\n"," else:\n"," n_clusters = int(n_clusters)\n","\n"," X, labels = DATA_MAPPING[dataset](n_clusters)\n"," model = MODEL_MAPPING[clustering_algorithm](\n"," X, labels, n_clusters=n_clusters)\n"," if hasattr(model, \"labels_\"):\n"," y_pred = model.labels_.astype(int)\n"," else:\n"," y_pred = model.predict(X)\n","\n"," fig, ax = plt.subplots(figsize=FIGSIZE)\n","\n"," plot_clusters(ax, X, y_pred)\n"," ax.set_title(clustering_algorithm, fontsize=16)\n","\n"," return fig\n","\n","\n","title = \"Clustering with Scikit-learn\"\n","description = (\n"," \"This example shows how different clustering algorithms work. Simply pick \"\n"," \"the dataset and the number of clusters to see how the clustering algorithms work. \"\n"," \"Colored circles are (predicted) labels and black x are outliers.\"\n",")\n","\n","\n","def iter_grid(n_rows, n_cols):\n"," # create a grid using gradio Block\n"," for _ in range(n_rows):\n"," with gr.Row():\n"," for _ in range(n_cols):\n"," with gr.Column():\n"," yield\n","\n","\n","with gr.Blocks(title=title) as demo:\n"," gr.HTML(f\"{title}\")\n"," gr.Markdown(description)\n","\n"," input_models = list(MODEL_MAPPING)\n"," input_data = gr.Radio(\n"," list(DATA_MAPPING),\n"," value=\"regular\",\n"," label=\"dataset\"\n"," )\n"," input_n_clusters = gr.Slider(\n"," minimum=1,\n"," maximum=MAX_CLUSTERS,\n"," value=4,\n"," step=1,\n"," label='Number of clusters'\n"," )\n"," n_rows = int(math.ceil(len(input_models) / N_COLS))\n"," counter = 0\n"," for _ in iter_grid(n_rows, N_COLS):\n"," if counter >= len(input_models):\n"," break\n","\n"," input_model = input_models[counter]\n"," plot = gr.Plot(label=input_model)\n"," fn = partial(cluster, clustering_algorithm=input_model)\n"," input_data.change(\n"," fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n"," input_n_clusters.change(\n"," fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n"," counter += 1\n","\n","demo.launch()"]}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":5} diff --git a/demo/clustering/run.py b/demo/clustering/run.py index e3a67c00c2..66a6f301f3 100644 --- a/demo/clustering/run.py +++ b/demo/clustering/run.py @@ -20,7 +20,8 @@ FIGSIZE = 7, 7 # does not affect size in webpage COLORS = [ 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan' ] -assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters" +if len(COLORS) <= MAX_CLUSTERS: + raise ValueError("Not enough different colors for all clusters") np.random.seed(SEED) diff --git a/gradio/blocks.py b/gradio/blocks.py index 074d965c4a..afc3335e1b 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -44,6 +44,7 @@ from gradio.exceptions import ( DuplicateBlockError, InvalidApiNameError, InvalidBlockError, + InvalidComponentError, ) from gradio.helpers import EventData, create_tracker, skip, special_args from gradio.state_holder import SessionState @@ -364,9 +365,10 @@ def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = attr_dict["__type__"] = "update" attr_dict.pop("value", None) if "value" in update_dict: - assert isinstance( - block, components.IOComponent - ), f"Component {block.__class__} does not support value" + if not isinstance(block, components.IOComponent): + raise InvalidComponentError( + f"Component {block.__class__} does not support value" + ) if postprocess: attr_dict["value"] = block.postprocess(update_dict["value"]) else: @@ -766,9 +768,10 @@ class Blocks(BlockContext): children = child_config.get("children") if children is not None: - assert isinstance( - block, BlockContext - ), f"Invalid config, Block with id {id} has children but is not a BlockContext." + if not isinstance(block, BlockContext): + raise ValueError( + f"Invalid config, Block with id {id} has children but is not a BlockContext." + ) with block: iterate_over_children(children) @@ -1158,7 +1161,8 @@ class Blocks(BlockContext): event_data: data associated with event trigger """ block_fn = self.fns[fn_index] - assert block_fn.fn, f"function with index {fn_index} not defined." + if not block_fn.fn: + raise IndexError(f"function with index {fn_index} not defined.") is_generating = False request = requests[0] if isinstance(requests, list) else requests start = time.time() @@ -1234,9 +1238,10 @@ class Blocks(BlockContext): raise InvalidBlockError( f"Input component with id {input_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events." ) from e - assert isinstance( - block, components.IOComponent - ), f"{block.__class__} Component with id {input_id} not a valid input component." + if not isinstance(block, components.IOComponent): + raise InvalidComponentError( + f"{block.__class__} Component with id {input_id} not a valid input component." + ) serialized_input = block.serialize(inputs[i]) processed_input.append(serialized_input) @@ -1253,9 +1258,10 @@ class Blocks(BlockContext): raise InvalidBlockError( f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events." ) from e - assert isinstance( - block, components.IOComponent - ), f"{block.__class__} Component with id {output_id} not a valid output component." + if not isinstance(block, components.IOComponent): + raise InvalidComponentError( + f"{block.__class__} Component with id {output_id} not a valid output component." + ) deserialized = block.deserialize( outputs[o], save_dir=block.DEFAULT_TEMP_DIR, @@ -1322,9 +1328,10 @@ Received inputs: raise InvalidBlockError( f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events." ) from e - assert isinstance( - block, components.Component - ), f"{block.__class__} Component with id {input_id} not a valid input component." + if not isinstance(block, components.Component): + raise InvalidComponentError( + f"{block.__class__} Component with id {input_id} not a valid input component." + ) if getattr(block, "stateful", False): processed_input.append(state[input_id]) else: @@ -1445,9 +1452,10 @@ Received outputs: postprocess=block_fn.postprocess, ) elif block_fn.postprocess: - assert isinstance( - block, components.Component - ), f"{block.__class__} Component with id {output_id} not a valid output component." + if not isinstance(block, components.Component): + raise InvalidComponentError( + f"{block.__class__} Component with id {output_id} not a valid output component." + ) prediction_value = block.postprocess(prediction_value) output.append(prediction_value) @@ -2005,9 +2013,8 @@ Received outputs: ) if self.is_running: - assert isinstance( - self.local_url, str - ), f"Invalid local_url: {self.local_url}" + if not isinstance(self.local_url, str): + raise ValueError(f"Invalid local_url: {self.local_url}") if not (quiet): print( "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----" diff --git a/gradio/components/chatbot.py b/gradio/components/chatbot.py index c61f226c00..447d42d99c 100644 --- a/gradio/components/chatbot.py +++ b/gradio/components/chatbot.py @@ -205,12 +205,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable): return y processed_messages = [] for message_pair in y: - assert isinstance( - message_pair, (tuple, list) - ), f"Expected a list of lists or list of tuples. Received: {message_pair}" - assert ( - len(message_pair) == 2 - ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + if not isinstance(message_pair, (tuple, list)): + raise TypeError( + f"Expected a list of lists or list of tuples. Received: {message_pair}" + ) + if len(message_pair) != 2: + raise TypeError( + f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + ) processed_messages.append( [ self._preprocess_chat_messages(message_pair[0]), @@ -259,12 +261,14 @@ class Chatbot(Changeable, Selectable, Likeable, IOComponent, JSONSerializable): return [] processed_messages = [] for message_pair in y: - assert isinstance( - message_pair, (tuple, list) - ), f"Expected a list of lists or list of tuples. Received: {message_pair}" - assert ( - len(message_pair) == 2 - ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + if not isinstance(message_pair, (tuple, list)): + raise TypeError( + f"Expected a list of lists or list of tuples. Received: {message_pair}" + ) + if len(message_pair) != 2: + raise TypeError( + f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + ) processed_messages.append( [ self._postprocess_chat_messages(message_pair[0]), diff --git a/gradio/components/code.py b/gradio/components/code.py index 2b3bd170ca..36b3264fef 100644 --- a/gradio/components/code.py +++ b/gradio/components/code.py @@ -81,7 +81,9 @@ class Code(Changeable, Inputable, IOComponent, StringSerializable): elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. """ - assert language in Code.languages, f"Language {language} not supported." + if language not in Code.languages: + raise ValueError(f"Language {language} not supported.") + self.language = language self.lines = lines IOComponent.__init__( diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py index d508c2d1d5..b53d48ff94 100644 --- a/gradio/components/dataframe.py +++ b/gradio/components/dataframe.py @@ -280,7 +280,8 @@ class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable return self.postprocess([[]]) if isinstance(y, np.ndarray): y = y.tolist() - assert isinstance(y, list), "output cannot be converted to list" + if not isinstance(y, list): + raise ValueError("output cannot be converted to list") _headers = self.headers if len(self.headers) < len(y[0]): diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 62c2b8b09c..cf4fa30834 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -68,9 +68,10 @@ class Dataset(Clickable, Selectable, Component, StringSerializable): self._components = [get_component_instance(c) for c in components] # Narrow type to IOComponent - assert all( - isinstance(c, IOComponent) for c in self._components - ), "All components in a `Dataset` must be subclasses of `IOComponent`" + if not all(isinstance(c, IOComponent) for c in self._components): + raise ValueError( + "All components in a `Dataset` must be subclasses of `IOComponent`" + ) self._components = [c for c in self._components if isinstance(c, IOComponent)] for component in self._components: component.root_url = self.root_url diff --git a/gradio/components/video.py b/gradio/components/video.py index 6845584586..5b1d3e5a74 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -193,14 +193,16 @@ class Video( ) if is_file: - assert file_name is not None, "Received file data without a file name." + if file_name is None: + raise ValueError("Received file data without a file name.") if client_utils.is_http_url_like(file_name): fn = self.download_temp_copy_if_needed else: fn = self.make_temp_copy_if_needed file_name = Path(fn(file_name)) else: - assert file_data is not None, "Received empty file data." + if file_data is None: + raise ValueError("Received empty file data.") file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name)) uploaded_format = file_name.suffix.replace(".", "") @@ -270,12 +272,15 @@ class Video( if isinstance(y, (str, Path)): processed_files = (self._format_video(y), None) elif isinstance(y, (tuple, list)): - assert ( - len(y) == 2 - ), f"Expected lists of length 2 or tuples of length 2. Received: {y}" - assert isinstance(y[0], (str, Path)) and isinstance( - y[1], (str, Path) - ), f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}" + if len(y) != 2: + raise ValueError( + f"Expected lists of length 2 or tuples of length 2. Received: {y}" + ) + + if not (isinstance(y[0], (str, Path)) and isinstance(y[1], (str, Path))): + raise TypeError( + f"If a tuple is provided, both elements must be strings or Path objects. Received: {y}" + ) video = y[0] subtitle = y[1] processed_files = ( diff --git a/gradio/exceptions.py b/gradio/exceptions.py index bca214460e..b8ed0989ae 100644 --- a/gradio/exceptions.py +++ b/gradio/exceptions.py @@ -9,12 +9,30 @@ class DuplicateBlockError(ValueError): pass +class InvalidComponentError(ValueError): + """Raised when invalid components are used.""" + + pass + + class TooManyRequestsError(Exception): """Raised when the Hugging Face API returns a 429 status code.""" pass +class ModelNotFoundError(Exception): + """Raised when the provided model doesn't exists or is not found by the provided api url.""" + + pass + + +class RenderError(Exception): + """Raised when a component has not been rendered in the current Blocks but is expected to have been rendered.""" + + pass + + class InvalidApiNameError(ValueError): pass diff --git a/gradio/external.py b/gradio/external.py index 88710ca7c4..edf08b3d42 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -16,7 +16,7 @@ import gradio from gradio import components, utils from gradio.context import Context from gradio.deprecation import warn_deprecation -from gradio.exceptions import Error, TooManyRequestsError +from gradio.exceptions import Error, ModelNotFoundError, TooManyRequestsError from gradio.external_utils import ( cols_to_rows, encode_to_base64, @@ -83,9 +83,10 @@ def load_blocks_from_repo( if src is None: # Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224") tokens = name.split("/") - assert ( - len(tokens) > 1 - ), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}" + if len(tokens) <= 1: + raise ValueError( + "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}" + ) src = tokens[0] name = "/".join(tokens[1:]) @@ -95,9 +96,8 @@ def load_blocks_from_repo( "models": from_model, "spaces": from_spaces, } - assert ( - src.lower() in factory_methods - ), f"parameter: src must be one of {factory_methods.keys()}" + if src.lower() not in factory_methods: + raise ValueError(f"parameter: src must be one of {factory_methods.keys()}") if hf_token is not None: if Context.hf_token is not None and Context.hf_token != hf_token: @@ -145,9 +145,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg # Checking if model exists, and if so, it gets the pipeline response = requests.request("GET", api_url, headers=headers) - assert ( - response.status_code == 200 - ), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter." + if response.status_code != 200: + raise ModelNotFoundError( + f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter." + ) p = response.json().get("pipeline_tag") pipelines = { "audio-classification": { diff --git a/gradio/helpers.py b/gradio/helpers.py index 79b6d2f6d2..70df46d9f9 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -525,7 +525,8 @@ class Progress(Iterable): ): current_iterable = self.iterables.pop() callback(self.iterables) - assert current_iterable.index is not None, "Index not set." + if current_iterable.index is None: + raise IndexError("Index not set.") current_iterable.index += 1 try: return next(current_iterable.iterable) # type: ignore @@ -603,7 +604,8 @@ class Progress(Iterable): callback = self._progress_callback() if callback and len(self.iterables) > 0: current_iterable = self.iterables[-1] - assert current_iterable.index is not None, "Index not set." + if current_iterable.index is None: + raise IndexError("Index not set.") current_iterable.index += n callback(self.iterables) else: diff --git a/gradio/interface.py b/gradio/interface.py index d2db6b78c6..b2fe192eb8 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,6 +28,7 @@ from gradio.components import ( from gradio.data_classes import InterfaceTypes from gradio.deprecation import warn_deprecation from gradio.events import Changeable, Streamable, Submittable, on +from gradio.exceptions import RenderError from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod from gradio.layouts import Column, Row, Tab, Tabs from gradio.pipelines import load_from_pipeline @@ -449,7 +450,8 @@ class Interface(Blocks): stop_btn = stop_btn or stop_btn_2_out flag_btns = flag_btns or flag_btns_out - assert clear_btn is not None, "Clear button not rendered" + if clear_btn is None: + raise RenderError("Clear button not rendered") self.attach_submit_events(submit_btn, stop_btn) self.attach_clear_events( @@ -586,7 +588,8 @@ class Interface(Blocks): if self.allow_flagging == "manual": flag_btns = self.render_flag_btns() elif self.allow_flagging == "auto": - assert submit_btn is not None, "Submit button not rendered" + if submit_btn is None: + raise RenderError("Submit button not rendered") flag_btns = [submit_btn] if self.interpretation: @@ -611,7 +614,8 @@ class Interface(Blocks): def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None): if self.live: if self.interface_type == InterfaceTypes.OUTPUT_ONLY: - assert submit_btn is not None, "Submit button not rendered" + if submit_btn is None: + raise RenderError("Submit button not rendered") super().load(self.fn, None, self.output_components) # For output-only interfaces, the user probably still want a "generate" # button even if the Interface is live @@ -642,7 +646,8 @@ class Interface(Blocks): postprocess=not (self.api_mode), ) else: - assert submit_btn is not None, "Submit button not rendered" + if submit_btn is None: + raise RenderError("Submit button not rendered") fn = self.fn extra_output = [] diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 767ad641b9..3731a9d381 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -230,7 +230,8 @@ async def run_interpret(interface: Interface, raw_input: list): nsamples=int(interface.num_shap * num_total_segments), silent=True, ) - assert shap_values is not None, "SHAP values could not be calculated" + if shap_values is None: + raise ValueError("SHAP values could not be calculated") scores.append( input_component.get_interpretation_scores( raw_input[i], diff --git a/gradio/queueing.py b/gradio/queueing.py index 54be0d116f..5bcc51e1ba 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -368,7 +368,8 @@ class Queue: async def call_prediction(self, events: list[Event], batch: bool): body = events[0].data - assert body is not None, "No event data" + if body is None: + raise ValueError("No event data") username = events[0].username body.event_id = events[0]._id if not batch else None try: diff --git a/gradio/ranged_response.py b/gradio/ranged_response.py index 88eb696184..f488776e6c 100644 --- a/gradio/ranged_response.py +++ b/gradio/ranged_response.py @@ -57,7 +57,10 @@ class RangedFileResponse(Response): stat_result: os.stat_result | None = None, method: str | None = None, ) -> None: - assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse" + if aiofiles is None: + raise ModuleNotFoundError( + "'aiofiles' must be installed to use FileResponse" + ) self.path = path self.range = range self.filename = filename diff --git a/gradio/utils.py b/gradio/utils.py index a5db6417bf..a7356397fa 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -339,11 +339,11 @@ def assert_configs_are_equivalent_besides_ids( config2 = json.loads(json.dumps(config2)) for key in root_keys: - assert config1[key] == config2[key], f"Configs have different: {key}" + if config1[key] != config2[key]: + raise ValueError(f"Configs have different: {key}") - assert len(config1["components"]) == len( - config2["components"] - ), "# of components are different" + if len(config1["components"]) != len(config2["components"]): + raise ValueError("# of components are different") def assert_same_components(config1_id, config2_id): c1 = list(filter(lambda c: c["id"] == config1_id, config1["components"])) @@ -358,7 +358,8 @@ def assert_configs_are_equivalent_besides_ids( c1.pop("id") c2 = copy.deepcopy(c2) c2.pop("id") - assert c1 == c2, f"{c1} does not match {c2}" + if c1 != c2: + raise ValueError(f"{c1} does not match {c2}") def same_children_recursive(children1, chidren2): for child1, child2 in zip(children1, chidren2): @@ -378,7 +379,8 @@ def assert_configs_are_equivalent_besides_ids( for o1, o2 in zip(d1.pop("outputs"), d2.pop("outputs")): assert_same_components(o1, o2) - assert d1 == d2, f"{d1} does not match {d2}" + if d1 != d2: + raise ValueError(f"{d1} does not match {d2}") return True diff --git a/test/test_utils.py b/test/test_utils.py index 45da2ffe3d..533afe341b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -115,7 +115,7 @@ def test_assert_configs_are_equivalent(): assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config) assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config_diff_ids) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): assert_configs_are_equivalent_besides_ids(xray_config, xray_config_wrong)