mirror of https://github.com/gradio-app/gradio.git
commit ac22485d33
@@ -2,6 +2,7 @@ import base64
 import json
 import re
 import tempfile
+from pydantic import MissingError
 
 import requests
 
@@ -24,13 +25,27 @@ def get_huggingface_interface(model_name, api_key, alias):
     p = response.json().get("pipeline_tag")
 
     def encode_to_base64(r: requests.Response) -> str:
+        # Handles the different ways HF API returns the prediction
         base64_repr = base64.b64encode(r.content).decode("utf-8")
         data_prefix = ";base64,"
+        # Case 1: base64 representation already includes data prefix
         if data_prefix in base64_repr:
             return base64_repr
         else:
             content_type = r.headers.get("content-type")
-            return "data:{};base64,".format(content_type) + base64_repr
+            # Case 2: the data prefix is a key in the response
+            if content_type == "application/json":
+                try:
+                    content_type = r.json()[0]["content-type"]
+                    base64_repr = r.json()[0]["blob"]
+                except KeyError:
+                    raise ValueError("Cannot determine content type returned"
+                                     "by external API.")
+            # Case 3: the data prefix is included in the response headers
+            else:
+                pass
+            new_base64 = "data:{};base64,".format(content_type) + base64_repr
+            return new_base64
 
     pipelines = {
         "audio-classification": {
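
The fallback logic added above can be hard to follow inside a diff. Below is a minimal, self-contained sketch (not part of the commit) of the same idea: prefer a content-type/blob pair found in a JSON body, otherwise fall back to the content-type response header. FakeResponse is a hypothetical stand-in for requests.Response and to_data_uri is a simplified mirror of encode_to_base64, both introduced here only for illustration.

import base64
import json


class FakeResponse:
    # hypothetical stand-in exposing the parts of requests.Response used here
    def __init__(self, content: bytes, headers: dict):
        self.content = content
        self.headers = headers

    def json(self):
        return json.loads(self.content)


def to_data_uri(r) -> str:
    # Prefer the content-type/blob keys in a JSON body, otherwise fall back
    # to the content-type header (same idea as the new encode_to_base64).
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    content_type = r.headers.get("content-type")
    if content_type == "application/json":
        content_type = r.json()[0]["content-type"]
        base64_repr = r.json()[0]["blob"]
    return "data:{};base64,".format(content_type) + base64_repr


# Case: raw audio bytes with an audio/flac content-type header
raw = FakeResponse(b"\x00\x01", {"content-type": "audio/flac"})
print(to_data_uri(raw))  # data:audio/flac;base64,AAE=

# Case: JSON body that itself carries the blob and its content type
blob = base64.b64encode(b"\x00\x01").decode("utf-8")
body = json.dumps([{"content-type": "audio/wav", "blob": blob}]).encode("utf-8")
wrapped = FakeResponse(body, {"content-type": "application/json"})
print(to_data_uri(wrapped))  # data:audio/wav;base64,AAE=
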
@@ -44,6 +59,15 @@ def get_huggingface_interface(model_name, api_key, alias):
                 i["label"].split(", ")[0]: i["score"] for i in r.json()
             },
         },
+        "audio-to-audio": {
+            # example model: https://hf.co/speechbrain/mtl-mimic-voicebank
+            "inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
+            "outputs": outputs.Audio(label="Output"),
+            "preprocess": lambda i: base64.b64decode(
+                i["data"].split(",")[1]
+            ),  # convert the base64 representation to binary
+            "postprocess": encode_to_base64,
+        },
         "automatic-speech-recognition": {
             # example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
             "inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
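
For context, a small sketch (not part of the commit) of the round trip the new "audio-to-audio" entry performs: preprocess unwraps the base64 data URI produced by Gradio's Audio input into raw bytes for the Inference API call, and postprocess (encode_to_base64 above) wraps the API's binary reply back into a data URI. The payload below is a made-up example shaped like that input dict.

import base64

# same shape as the "preprocess" lambda in the pipeline entry above
preprocess = lambda i: base64.b64decode(i["data"].split(",")[1])

# hypothetical payload, shaped like the dict produced by inputs.Audio
payload = {"data": "data:audio/wav;base64," + base64.b64encode(b"RIFF....").decode("utf-8")}

audio_bytes = preprocess(payload)
print(audio_bytes[:4])  # b'RIFF' -- the raw bytes that would be POSTed to the API
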
@@ -183,7 +207,7 @@ def get_huggingface_interface(model_name, api_key, alias):
     }
 
     if p is None or not (p in pipelines):
-        raise ValueError("Unsupported pipeline type: {}".format(type(p)))
+        raise ValueError("Unsupported pipeline type: {}".format(p))
 
     pipeline = pipelines[p]
 
@@ -273,7 +297,10 @@ def get_spaces_interface(model_name, api_key, alias):
     result = re.search(
         "window.gradio_config = (.*?);</script>", r.text
     )  # some basic regex to extract the config
-    config = json.loads(result.group(1))
+    try:
+        config = json.loads(result.group(1))
+    except AttributeError:
+        raise ValueError("Could not load the Space: {}".format(model_name))
     interface_info = interface_params_from_config(config)
 
     # The function should call the API with preprocessed data
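
The try/except added above addresses a specific failure mode: when the fetched page contains no window.gradio_config blob, re.search returns None and result.group(1) raises AttributeError, which previously surfaced as an unhelpful traceback. A short sketch of that path (made-up page text and Space name; print stands in for the ValueError the new code raises):

import json
import re

page_text = "<html><body>Space not found</body></html>"  # no gradio_config here
result = re.search("window.gradio_config = (.*?);</script>", page_text)

try:
    config = json.loads(result.group(1))  # result is None, so this raises AttributeError
except AttributeError:
    # the commit converts this into ValueError("Could not load the Space: ...")
    print("Could not load the Space: some-user/some-space")
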
@@ -16,6 +16,18 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 
 
 class TestHuggingFaceModelAPI(unittest.TestCase):
+    def test_audio_to_audio(self):
+        model_type = "audio-to-audio"
+        interface_info = gr.external.get_huggingface_interface(
+            "speechbrain/mtl-mimic-voicebank",
+            api_key=None,
+            alias=model_type,
+        )
+        self.assertEqual(interface_info["fn"].__name__, model_type)
+        self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
+        self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
+
+
     def test_question_answering(self):
         model_type = "question-answering"
         interface_info = gr.external.get_huggingface_interface(
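
The new test exercises the same path end users reach through gr.Interface.load. A hedged usage sketch (requires network access and assumes the speechbrain/mtl-mimic-voicebank model is still served by the Inference API):

import gradio as gr

# Builds an Interface whose prediction function calls the HF Inference API
# using the audio-to-audio pipeline entry added in this commit.
io = gr.Interface.load("huggingface/speechbrain/mtl-mimic-voicebank")
io.launch()  # serves the demo locally
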