Skops integration: Load tabular classification and regression models from the hub (#2126)
* MVP of skops integration
* Add unit tests
* One more case
* Fix NaNs in widget data
* Remove breakpoint
* Fix typo
This commit is contained in: parent b145a2a191, commit eb81fa2cf2
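For context, a minimal usage sketch of what this commit enables, assuming network access to the Hugging Face Inference API (the model name is the one used in the tests below):

import gradio as gr

# Build a demo from a tabular scikit-learn model hosted on the Hub; example rows
# are pulled from the widget metadata in the model card's README.
io = gr.Interface.load("models/scikit-learn/tabular-playground")
io.launch()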
@@ -1,18 +1,27 @@
"""This module should not be used directly as its API is subject to change. Instead,
use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""

from __future__ import annotations

import base64
import json
import math
import numbers
import operator
import re
import warnings
from copy import deepcopy
-from typing import Callable, Dict
+from typing import TYPE_CHECKING, Callable, Dict, List, Tuple

import requests
import yaml

import gradio
from gradio import components, utils

if TYPE_CHECKING:
    from gradio.components import DataframeData


class TooManyRequestsError(Exception):
    """Raised when the Hugging Face API returns a 429 status code."""
@@ -42,6 +51,58 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
    return blocks


def get_tabular_examples(model_name) -> Dict[str, List[float]]:
    readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
        example_data = {}
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
        example_data = example_yaml.get("widget", {}).get("structuredData", {})
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference API
    for data in example_data.values():
        for i, val in enumerate(data):
            if isinstance(val, numbers.Number) and math.isnan(val):
                data[i] = "NaN"
    return example_data
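For a sense of what this helper returns, a rough sketch, assuming the model card's YAML front matter carries a widget.structuredData block like the one in the unit test further down (the call needs network access; the model name is only illustrative):

examples = get_tabular_examples("scikit-learn/tabular-playground")
# e.g. {"attribute_0": ["material_7", ...], "measurement_2": [14.206, 15.094, "NaN"]}
# Float NaNs come back as the string "NaN" so the Inference API can serialize them.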


def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    headers = list(example_data.keys())
    n_rows = max(len(example_data[header] or []) for header in headers)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
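A quick illustration of the column-to-row conversion, taken from one of the test cases below; shorter columns are padded with the string "NaN":

headers, rows = cols_to_rows({"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]})
# headers == ["a", "b"]
# rows == [[1, 1], [2, "NaN"], ["NaN", 3], [4, "NaN"]]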


def rows_to_cols(
    incoming_data: DataframeData,
) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    data_column_wise = {}
    for i, header in enumerate(incoming_data["headers"]):
        data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
    return {"inputs": {"data": data_column_wise}}


def get_models_interface(model_name, api_key, alias, **kwargs):
    model_url = "https://huggingface.co/{}".format(model_name)
    api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
@@ -260,6 +321,29 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
        },
    }

    if p in ["tabular-classification", "tabular-regression"]:
        example_data = get_tabular_examples(model_name)
        col_names, example_data = cols_to_rows(example_data)
        example_data = [[example_data]] if example_data else None

        pipelines[p] = {
            "inputs": components.Dataframe(
                label="Input Rows",
                type="pandas",
                headers=col_names,
                col_count=(len(col_names), "fixed"),
            ),
            "outputs": components.Dataframe(
                label="Predictions", type="array", headers=["prediction"]
            ),
            "preprocess": rows_to_cols,
            "postprocess": lambda r: {
                "headers": ["prediction"],
                "data": [[pred] for pred in json.loads(r.text)],
            },
            "examples": example_data,
        }

    if p is None or not (p in pipelines):
        raise ValueError("Unsupported pipeline type: {}".format(p))
@@ -275,10 +359,16 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
        data = json.dumps(data)
        response = requests.request("POST", api_url, headers=headers, data=data)
        if not (response.status_code == 200):
            errors_json = response.json()
            errors, warns = "", ""
            if errors_json.get("error"):
                errors = f", Error: {errors_json.get('error')}"
            if errors_json.get("warnings"):
                warns = f", Warnings: {errors_json.get('warnings')}"
            raise ValueError(
-                "Could not complete request to HuggingFace API, Error {}".format(
-                    response.status_code
-                )
+                f"Could not complete request to HuggingFace API, Status Code: {response.status_code}"
                + errors
                + warns
            )
        if (
            p == "token-classification"
@@ -299,6 +389,7 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
        "inputs": pipeline["inputs"],
        "outputs": pipeline["outputs"],
        "title": model_name,
        "examples": pipeline.get("examples"),
    }

    kwargs = dict(interface_info, **kwargs)
@@ -13,6 +13,7 @@ pillow
pycryptodome
python-multipart
pydub
pyyaml
requests
uvicorn
Jinja2
@@ -1,14 +1,16 @@
import json
import os
import pathlib
import textwrap
import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

import pytest
import transformers

import gradio as gr
-from gradio.external import TooManyRequestsError
+from gradio import utils
+from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples

"""
WARNING: These tests have an external dependency: namely that Hugging Face's
@@ -242,5 +244,92 @@ def test_interface_load_cache_examples(tmp_path):
        )


def test_get_tabular_examples_replaces_nan_with_str_nan():
    readme = """
        ---
        tags:
        - sklearn
        - skops
        - tabular-classification
        widget:
          structuredData:
            attribute_0:
            - material_7
            - material_7
            - material_7
            measurement_2:
            - 14.206
            - 15.094
            - .nan
        ---
    """
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.text = textwrap.dedent(readme)

    with patch("gradio.external.requests.get", return_value=mock_response):
        examples = get_tabular_examples("foo-model")
        assert examples["measurement_2"] == [14.206, 15.094, "NaN"]


def test_cols_to_rows():
    assert cols_to_rows({"a": [1, 2, "NaN"], "b": [1, "NaN", 3]}) == (
        ["a", "b"],
        [[1, 1], [2, "NaN"], ["NaN", 3]],
    )
    assert cols_to_rows({"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]}) == (
        ["a", "b"],
        [[1, 1], [2, "NaN"], ["NaN", 3], [4, "NaN"]],
    )
    assert cols_to_rows({"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]}) == (
        ["a", "b"],
        [[1, 1], [2, "NaN"], ["NaN", 3], ["NaN", 5]],
    )
    assert cols_to_rows({"a": None, "b": [1, "NaN", 3, 5]}) == (
        ["a", "b"],
        [["NaN", 1], ["NaN", "NaN"], ["NaN", 3], ["NaN", 5]],
    )
    assert cols_to_rows({"a": None, "b": None}) == (["a", "b"], [])


def check_dataframe(config):
    input_df = next(
        c for c in config["components"] if c["props"].get("label", "") == "Input Rows"
    )
    assert input_df["props"]["headers"] == ["a", "b"]
    assert input_df["props"]["row_count"] == (1, "dynamic")
    assert input_df["props"]["col_count"] == (2, "fixed")


def check_dataset(config, readme_examples):
    # No Examples
    if not any(readme_examples.values()):
        assert not any([c for c in config["components"] if c["type"] == "dataset"])
    else:
        dataset = next(c for c in config["components"] if c["type"] == "dataset")
        assert dataset["props"]["samples"] == [
            [utils.delete_none(cols_to_rows(readme_examples)[1])]
        ]


@pytest.mark.parametrize(
    "hypothetical_readme",
    [
        {"a": [1, 2, "NaN"], "b": [1, "NaN", 3]},
        {"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]},
        {"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]},
        {"a": None, "b": [1, "NaN", 3, 5]},
        {"a": None, "b": None},
    ],
)
def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
    with patch(
        "gradio.external.get_tabular_examples", return_value=hypothetical_readme
    ):
        io = gr.Interface.load("models/scikit-learn/tabular-playground")
        check_dataframe(io.config)
        check_dataset(io.config, hypothetical_readme)


if __name__ == "__main__":
    unittest.main()