Merge pull request #535 from gradio-app/AK391/master

Ak391/master
This commit is contained in:
Abubakar Abid 2022-02-03 11:24:28 -05:00 committed by GitHub
commit ac22485d33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 3 deletions

View File

@ -2,6 +2,7 @@ import base64
import json
import re
import tempfile
from pydantic import MissingError
import requests
@ -24,13 +25,27 @@ def get_huggingface_interface(model_name, api_key, alias):
p = response.json().get("pipeline_tag")
def encode_to_base64(r: requests.Response) -> str:
# Handles the different ways HF API returns the prediction
base64_repr = base64.b64encode(r.content).decode("utf-8")
data_prefix = ";base64,"
# Case 1: base64 representation already includes data prefix
if data_prefix in base64_repr:
return base64_repr
else:
content_type = r.headers.get("content-type")
return "data:{};base64,".format(content_type) + base64_repr
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
content_type = r.json()[0]["content-type"]
base64_repr = r.json()[0]["blob"]
except KeyError:
raise ValueError("Cannot determine content type returned"
"by external API.")
# Case 3: the data prefix is included in the response headers
else:
pass
new_base64 = "data:{};base64,".format(content_type) + base64_repr
return new_base64
pipelines = {
"audio-classification": {
@ -44,6 +59,15 @@ def get_huggingface_interface(model_name, api_key, alias):
i["label"].split(", ")[0]: i["score"] for i in r.json()
},
},
"audio-to-audio": {
# example model: https://hf.co/speechbrain/mtl-mimic-voicebank
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
"outputs": outputs.Audio(label="Output"),
"preprocess": lambda i: base64.b64decode(
i["data"].split(",")[1]
), # convert the base64 representation to binary
"postprocess": encode_to_base64,
},
"automatic-speech-recognition": {
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
@ -183,7 +207,7 @@ def get_huggingface_interface(model_name, api_key, alias):
}
if p is None or not (p in pipelines):
raise ValueError("Unsupported pipeline type: {}".format(type(p)))
raise ValueError("Unsupported pipeline type: {}".format(p))
pipeline = pipelines[p]
@ -273,7 +297,10 @@ def get_spaces_interface(model_name, api_key, alias):
result = re.search(
"window.gradio_config = (.*?);</script>", r.text
) # some basic regex to extract the config
config = json.loads(result.group(1))
try:
config = json.loads(result.group(1))
except AttributeError:
raise ValueError("Could not load the Space: {}".format(model_name))
interface_info = interface_params_from_config(config)
# The function should call the API with preprocessed data

View File

@ -16,6 +16,18 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestHuggingFaceModelAPI(unittest.TestCase):
def test_audio_to_audio(self):
model_type = "audio-to-audio"
interface_info = gr.external.get_huggingface_interface(
"speechbrain/mtl-mimic-voicebank",
api_key=None,
alias=model_type,
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
def test_question_answering(self):
model_type = "question-answering"
interface_info = gr.external.get_huggingface_interface(