Fixed audio tokenization

This commit is contained in:
Abubakar Abid 2022-01-25 16:22:46 -06:00
parent c1dc85df17
commit 19d95c460d
4 changed files with 22 additions and 11 deletions

View File

@ -6,6 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
from __future__ import annotations
import os
import json
import math
import tempfile
@ -1182,8 +1183,7 @@ class Audio(InputComponent):
return self
def tokenize(self, x):
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(x)
sample_rate, data = processing_utils.audio_from_file(x["name"])
leave_one_out_sets = []
tokens = []
masks = []
@ -1193,20 +1193,27 @@ class Audio(InputComponent):
for index in range(len(boundaries) - 1):
start, stop = boundaries[index], boundaries[index + 1]
masks.append((start, stop))
# Handle the leave one outs
leave_one_out_data = np.copy(data)
leave_one_out_data[start:stop] = 0
file = tempfile.NamedTemporaryFile(delete=False)
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
out_data = processing_utils.encode_file_to_base64(file.name)
leave_one_out_sets.append(out_data)
file.close()
os.unlink(file.name)
# Handle the tokens
token = np.copy(data)
token[0:start] = 0
token[stop:] = 0
file = tempfile.NamedTemporaryFile(delete=False)
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
processing_utils.audio_to_file(sample_rate, token, file.name)
token_data = processing_utils.encode_file_to_base64(file.name)
file.close()
os.unlink(file.name)
tokens.append(token_data)
return tokens, leave_one_out_sets, masks
@ -1215,7 +1222,7 @@ class Audio(InputComponent):
x = tokens[0]
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
zero_input = np.zeros_like(data, dtype=int)
zero_input = np.zeros_like(data, dtype='int16')
# decode all of the tokens
token_data = []
for token in tokens:
@ -1229,8 +1236,10 @@ class Audio(InputComponent):
for t, b in zip(token_data, binary_mask_vector):
masked_input = masked_input + t * int(b)
file = tempfile.NamedTemporaryFile(delete=False)
processing_utils.audio_to_file(sample_rate, masked_input, file_obj.name)
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
masked_data = processing_utils.encode_file_to_base64(file.name)
file.close()
os.unlink(file.name)
masked_inputs.append(masked_data)
return masked_inputs

View File

@ -17,7 +17,7 @@ import webbrowser
from logging import warning
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import markdown2 # type: ignore
import markdown2
from gradio import (encryptor, interpretation, networking, # type: ignore
queueing, strings, utils)

View File

@ -144,7 +144,7 @@ def audio_to_file(sample_rate, data, filename):
sample_width=data.dtype.itemsize,
channels=(1 if len(data.shape) == 1 else data.shape[1]),
)
audio.export(filename, format="wav")
audio.export(filename, format="wav").close()
##################

View File

@ -1,3 +1,4 @@
from difflib import SequenceMatcher
import json
import os
import tempfile
@ -567,12 +568,13 @@ class TestAudio(unittest.TestCase):
def test_tokenize(self):
x_wav = gr.test_data.BASE64_AUDIO
audio_input = gr.inputs.Audio()
audio_input = gr.inputs.Audio()
tokens, _, _ = audio_input.tokenize(x_wav)
self.assertEquals(len(tokens), audio_input.interpretation_segments)
x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0]
self.assertEquals(x_new, x_wav)
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
self.assertGreater(similarity, 0.9)
class TestFile(unittest.TestCase):
def test_as_component(self):