mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Fixed audio to audio & better error messaging if Space is not loading
This commit is contained in:
parent
1893ef6140
commit
8e632b84dc
@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import tempfile
|
||||
from pydantic import MissingError
|
||||
|
||||
import requests
|
||||
|
||||
@ -24,13 +25,27 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
p = response.json().get("pipeline_tag")
|
||||
|
||||
def encode_to_base64(r: requests.Response) -> str:
|
||||
# Handles the different ways HF API returns the prediction
|
||||
base64_repr = base64.b64encode(r.content).decode("utf-8")
|
||||
data_prefix = ";base64,"
|
||||
# Case 1: base64 representation already includes data prefix
|
||||
if data_prefix in base64_repr:
|
||||
return base64_repr
|
||||
else:
|
||||
content_type = r.headers.get("content-type")
|
||||
return "data:{};base64,".format(content_type) + base64_repr
|
||||
# Case 2: the data prefix is a key in the response
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
content_type = r.json()[0]["content-type"]
|
||||
base64_repr = r.json()[0]["blob"]
|
||||
except KeyError:
|
||||
raise ValueError("Cannot determine content type returned"
|
||||
"by external API.")
|
||||
# Case 3: the data prefix is included in the response headers
|
||||
else:
|
||||
pass
|
||||
new_base64 = "data:{};base64,".format(content_type) + base64_repr
|
||||
return new_base64
|
||||
|
||||
pipelines = {
|
||||
"audio-classification": {
|
||||
@ -45,7 +60,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
},
|
||||
},
|
||||
"audio-to-audio": {
|
||||
# example model: https://hf.co/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
|
||||
# example model: https://hf.co/speechbrain/mtl-mimic-voicebank
|
||||
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
|
||||
"outputs": outputs.Audio(label="Output"),
|
||||
"preprocess": lambda i: base64.b64decode(
|
||||
@ -192,7 +207,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
}
|
||||
|
||||
if p is None or not (p in pipelines):
|
||||
raise ValueError("Unsupported pipeline type: {}".format(type(p)))
|
||||
raise ValueError("Unsupported pipeline type: {}".format(p))
|
||||
|
||||
pipeline = pipelines[p]
|
||||
|
||||
@ -282,7 +297,10 @@ def get_spaces_interface(model_name, api_key, alias):
|
||||
result = re.search(
|
||||
"window.gradio_config = (.*?);</script>", r.text
|
||||
) # some basic regex to extract the config
|
||||
config = json.loads(result.group(1))
|
||||
try:
|
||||
config = json.loads(result.group(1))
|
||||
except AttributeError:
|
||||
raise ValueError("Could not load the Space: {}".format(model_name))
|
||||
interface_info = interface_params_from_config(config)
|
||||
|
||||
# The function should call the API with preprocessed data
|
||||
|
@ -19,7 +19,7 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_audio_to_audio(self):
|
||||
model_type = "audio-to-audio"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/xm_transformer_600m-es_en-multi_domain",
|
||||
"speechbrain/mtl-mimic-voicebank",
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user