diff --git a/gradio/external.py b/gradio/external.py index e27530488d..73d26851a5 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -1,18 +1,27 @@ """This module should not be used directly as its API is subject to change. Instead, use the `gr.Blocks.load()` or `gr.Interface.load()` functions.""" +from __future__ import annotations + import base64 import json +import math +import numbers import operator import re +import warnings from copy import deepcopy -from typing import Callable, Dict +from typing import TYPE_CHECKING, Callable, Dict, List, Tuple import requests +import yaml import gradio from gradio import components, utils +if TYPE_CHECKING: + from gradio.components import DataframeData + class TooManyRequestsError(Exception): """Raised when the Hugging Face API returns a 429 status code.""" @@ -42,6 +51,58 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs): return blocks +def get_tabular_examples(model_name) -> Dict[str, List[float]]: + readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md") + if readme.status_code != 200: + warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning) + example_data = {} + else: + yaml_regex = re.search( + "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text + ) + example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])) + example_data = example_yaml.get("widget", {}).get("structuredData", {}) + if not example_data: + raise ValueError( + f"No example data found in README.md of {model_name} - Cannot build gradio demo. " + "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md " + "for a reference on how to provide example data to your model." + ) + # replace nan with string NaN for inference API + for data in example_data.values(): + for i, val in enumerate(data): + if isinstance(val, numbers.Number) and math.isnan(val): + data[i] = "NaN" + return example_data + + +def cols_to_rows( + example_data: Dict[str, List[float]] +) -> Tuple[List[str], List[List[float]]]: + headers = list(example_data.keys()) + n_rows = max(len(example_data[header] or []) for header in headers) + data = [] + for row_index in range(n_rows): + row_data = [] + for header in headers: + col = example_data[header] or [] + if row_index >= len(col): + row_data.append("NaN") + else: + row_data.append(col[row_index]) + data.append(row_data) + return headers, data + + +def rows_to_cols( + incoming_data: DataframeData, +) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + data_column_wise = {} + for i, header in enumerate(incoming_data["headers"]): + data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]] + return {"inputs": {"data": data_column_wise}} + + def get_models_interface(model_name, api_key, alias, **kwargs): model_url = "https://huggingface.co/{}".format(model_name) api_url = "https://api-inference.huggingface.co/models/{}".format(model_name) @@ -260,6 +321,29 @@ def get_models_interface(model_name, api_key, alias, **kwargs): }, } + if p in ["tabular-classification", "tabular-regression"]: + example_data = get_tabular_examples(model_name) + col_names, example_data = cols_to_rows(example_data) + example_data = [[example_data]] if example_data else None + + pipelines[p] = { + "inputs": components.Dataframe( + label="Input Rows", + type="pandas", + headers=col_names, + col_count=(len(col_names), "fixed"), + ), + "outputs": components.Dataframe( + label="Predictions", type="array", headers=["prediction"] + ), + "preprocess": rows_to_cols, + "postprocess": lambda r: { + "headers": ["prediction"], + "data": [[pred] for pred in json.loads(r.text)], + }, + "examples": example_data, + } + if p is None or not (p in pipelines): raise ValueError("Unsupported pipeline type: {}".format(p)) @@ -275,10 +359,16 @@ def get_models_interface(model_name, api_key, alias, **kwargs): data = json.dumps(data) response = requests.request("POST", api_url, headers=headers, data=data) if not (response.status_code == 200): + errors_json = response.json() + errors, warns = "", "" + if errors_json.get("error"): + errors = f", Error: {errors_json.get('error')}" + if errors_json.get("warnings"): + warns = f", Warnings: {errors_json.get('warnings')}" raise ValueError( - "Could not complete request to HuggingFace API, Error {}".format( - response.status_code - ) + f"Could not complete request to HuggingFace API, Status Code: {response.status_code}" + + errors + + warns ) if ( p == "token-classification" @@ -299,6 +389,7 @@ def get_models_interface(model_name, api_key, alias, **kwargs): "inputs": pipeline["inputs"], "outputs": pipeline["outputs"], "title": model_name, + "examples": pipeline.get("examples"), } kwargs = dict(interface_info, **kwargs) diff --git a/requirements.txt b/requirements.txt index f567b6f9f3..f6846707ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ pillow pycryptodome python-multipart pydub +pyyaml requests uvicorn Jinja2 diff --git a/test/test_external.py b/test/test_external.py index 9c4cd72feb..2837b89e24 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -1,14 +1,16 @@ import json import os import pathlib +import textwrap import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest import transformers import gradio as gr -from gradio.external import TooManyRequestsError +from gradio import utils +from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples """ WARNING: These tests have an external dependency: namely that Hugging Face's @@ -242,5 +244,92 @@ def test_interface_load_cache_examples(tmp_path): ) +def test_get_tabular_examples_replaces_nan_with_str_nan(): + readme = """ + --- + tags: + - sklearn + - skops + - tabular-classification + widget: + structuredData: + attribute_0: + - material_7 + - material_7 + - material_7 + measurement_2: + - 14.206 + - 15.094 + - .nan + --- + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = textwrap.dedent(readme) + + with patch("gradio.external.requests.get", return_value=mock_response): + examples = get_tabular_examples("foo-model") + assert examples["measurement_2"] == [14.206, 15.094, "NaN"] + + +def test_cols_to_rows(): + assert cols_to_rows({"a": [1, 2, "NaN"], "b": [1, "NaN", 3]}) == ( + ["a", "b"], + [[1, 1], [2, "NaN"], ["NaN", 3]], + ) + assert cols_to_rows({"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]}) == ( + ["a", "b"], + [[1, 1], [2, "NaN"], ["NaN", 3], [4, "NaN"]], + ) + assert cols_to_rows({"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]}) == ( + ["a", "b"], + [[1, 1], [2, "NaN"], ["NaN", 3], ["NaN", 5]], + ) + assert cols_to_rows({"a": None, "b": [1, "NaN", 3, 5]}) == ( + ["a", "b"], + [["NaN", 1], ["NaN", "NaN"], ["NaN", 3], ["NaN", 5]], + ) + assert cols_to_rows({"a": None, "b": None}) == (["a", "b"], []) + + +def check_dataframe(config): + input_df = next( + c for c in config["components"] if c["props"].get("label", "") == "Input Rows" + ) + assert input_df["props"]["headers"] == ["a", "b"] + assert input_df["props"]["row_count"] == (1, "dynamic") + assert input_df["props"]["col_count"] == (2, "fixed") + + +def check_dataset(config, readme_examples): + # No Examples + if not any(readme_examples.values()): + assert not any([c for c in config["components"] if c["type"] == "dataset"]) + else: + dataset = next(c for c in config["components"] if c["type"] == "dataset") + assert dataset["props"]["samples"] == [ + [utils.delete_none(cols_to_rows(readme_examples)[1])] + ] + + +@pytest.mark.parametrize( + "hypothetical_readme", + [ + {"a": [1, 2, "NaN"], "b": [1, "NaN", 3]}, + {"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]}, + {"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]}, + {"a": None, "b": [1, "NaN", 3, 5]}, + {"a": None, "b": None}, + ], +) +def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme): + with patch( + "gradio.external.get_tabular_examples", return_value=hypothetical_readme + ): + io = gr.Interface.load("models/scikit-learn/tabular-playground") + check_dataframe(io.config) + check_dataset(io.config, hypothetical_readme) + + if __name__ == "__main__": unittest.main()