mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Fixed audio tokenization
This commit is contained in:
parent
c1dc85df17
commit
19d95c460d
@ -6,6 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import tempfile
|
||||
@ -1182,8 +1183,7 @@ class Audio(InputComponent):
|
||||
return self
|
||||
|
||||
def tokenize(self, x):
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(x["name"])
|
||||
leave_one_out_sets = []
|
||||
tokens = []
|
||||
masks = []
|
||||
@ -1193,20 +1193,27 @@ class Audio(InputComponent):
|
||||
for index in range(len(boundaries) - 1):
|
||||
start, stop = boundaries[index], boundaries[index + 1]
|
||||
masks.append((start, stop))
|
||||
|
||||
# Handle the leave one outs
|
||||
leave_one_out_data = np.copy(data)
|
||||
leave_one_out_data[start:stop] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
||||
out_data = processing_utils.encode_file_to_base64(file.name)
|
||||
leave_one_out_sets.append(out_data)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
|
||||
# Handle the tokens
|
||||
token = np.copy(data)
|
||||
token[0:start] = 0
|
||||
token[stop:] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
processing_utils.audio_to_file(sample_rate, token, file.name)
|
||||
token_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
|
||||
tokens.append(token_data)
|
||||
return tokens, leave_one_out_sets, masks
|
||||
|
||||
@ -1215,7 +1222,7 @@ class Audio(InputComponent):
|
||||
x = tokens[0]
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
|
||||
zero_input = np.zeros_like(data, dtype=int)
|
||||
zero_input = np.zeros_like(data, dtype='int16')
|
||||
# decode all of the tokens
|
||||
token_data = []
|
||||
for token in tokens:
|
||||
@ -1229,8 +1236,10 @@ class Audio(InputComponent):
|
||||
for t, b in zip(token_data, binary_mask_vector):
|
||||
masked_input = masked_input + t * int(b)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file_obj.name)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||
masked_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
masked_inputs.append(masked_data)
|
||||
return masked_inputs
|
||||
|
||||
|
@ -17,7 +17,7 @@ import webbrowser
|
||||
from logging import warning
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
import markdown2 # type: ignore
|
||||
import markdown2
|
||||
|
||||
from gradio import (encryptor, interpretation, networking, # type: ignore
|
||||
queueing, strings, utils)
|
||||
|
@ -144,7 +144,7 @@ def audio_to_file(sample_rate, data, filename):
|
||||
sample_width=data.dtype.itemsize,
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1]),
|
||||
)
|
||||
audio.export(filename, format="wav")
|
||||
audio.export(filename, format="wav").close()
|
||||
|
||||
|
||||
##################
|
||||
|
@ -1,3 +1,4 @@
|
||||
from difflib import SequenceMatcher
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@ -567,12 +568,13 @@ class TestAudio(unittest.TestCase):
|
||||
|
||||
def test_tokenize(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
audio_input = gr.inputs.Audio()
|
||||
audio_input = gr.inputs.Audio()
|
||||
tokens, _, _ = audio_input.tokenize(x_wav)
|
||||
self.assertEquals(len(tokens), audio_input.interpretation_segments)
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0]
|
||||
self.assertEquals(x_new, x_wav)
|
||||
|
||||
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
|
||||
self.assertGreater(similarity, 0.9)
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user