mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Cache temp files created from base64 data (#3197)
* changes * added workflow * fix action * fix action * fix action * changelg * formatting * fix * Delete benchmark-queue.yml * Delete benchmark_queue.py * changelog * lint * fix tests * fix tests * fix for python 3.7 * formatting
This commit is contained in:
parent
74f75f004a
commit
752ec0ef6a
@ -55,6 +55,7 @@ No changes to highlight.
|
|||||||
|
|
||||||
## Full Changelog:
|
## Full Changelog:
|
||||||
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
|
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
|
||||||
|
* Caches temp files from base64 input data by giving them a deterministic path based on the contents of data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197)
|
||||||
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
|
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
|
||||||
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
|
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
|
||||||
|
|
||||||
|
@ -1808,13 +1808,9 @@ class Video(
|
|||||||
x.get("is_file", False),
|
x.get("is_file", False),
|
||||||
)
|
)
|
||||||
if is_file:
|
if is_file:
|
||||||
file = self.make_temp_copy_if_needed(file_name)
|
file_name = Path(self.make_temp_copy_if_needed(file_name))
|
||||||
file_name = Path(file)
|
|
||||||
else:
|
else:
|
||||||
file = processing_utils.decode_base64_to_file(
|
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))
|
||||||
file_data, file_path=file_name
|
|
||||||
)
|
|
||||||
file_name = Path(file.name)
|
|
||||||
|
|
||||||
uploaded_format = file_name.suffix.replace(".", "")
|
uploaded_format = file_name.suffix.replace(".", "")
|
||||||
modify_format = self.format is not None and uploaded_format != self.format
|
modify_format = self.format is not None and uploaded_format != self.format
|
||||||
@ -2041,20 +2037,26 @@ class Audio(
|
|||||||
else:
|
else:
|
||||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||||
else:
|
else:
|
||||||
temp_file_obj = processing_utils.decode_base64_to_file(
|
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
|
||||||
file_data, file_path=file_name
|
|
||||||
)
|
|
||||||
temp_file_path = temp_file_obj.name
|
|
||||||
|
|
||||||
sample_rate, data = processing_utils.audio_from_file(
|
sample_rate, data = processing_utils.audio_from_file(
|
||||||
temp_file_path, crop_min=crop_min, crop_max=crop_max
|
temp_file_path, crop_min=crop_min, crop_max=crop_max
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Need a unique name for the file to avoid re-using the same audio file if
|
||||||
|
# a user submits the same audio file twice, but with different crop min/max.
|
||||||
|
temp_file_path = Path(temp_file_path)
|
||||||
|
output_file_name = str(
|
||||||
|
temp_file_path.with_name(
|
||||||
|
f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.type == "numpy":
|
if self.type == "numpy":
|
||||||
return sample_rate, data
|
return sample_rate, data
|
||||||
elif self.type == "filepath":
|
elif self.type == "filepath":
|
||||||
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
|
processing_utils.audio_to_file(sample_rate, data, output_file_name)
|
||||||
return temp_file_path
|
return output_file_name
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown type: "
|
"Unknown type: "
|
||||||
@ -2075,8 +2077,8 @@ class Audio(
|
|||||||
if x.get("is_file"):
|
if x.get("is_file"):
|
||||||
sample_rate, data = processing_utils.audio_from_file(x["name"])
|
sample_rate, data = processing_utils.audio_from_file(x["name"])
|
||||||
else:
|
else:
|
||||||
file_obj = processing_utils.decode_base64_to_file(x["data"])
|
file_name = self.base64_to_temp_file_if_needed(x["data"])
|
||||||
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
|
sample_rate, data = processing_utils.audio_from_file(file_name)
|
||||||
leave_one_out_sets = []
|
leave_one_out_sets = []
|
||||||
tokens = []
|
tokens = []
|
||||||
masks = []
|
masks = []
|
||||||
@ -2117,14 +2119,14 @@ class Audio(
|
|||||||
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
||||||
# create a "zero input" vector and get sample rate
|
# create a "zero input" vector and get sample rate
|
||||||
x = tokens[0]["data"]
|
x = tokens[0]["data"]
|
||||||
file_obj = processing_utils.decode_base64_to_file(x)
|
file_name = self.base64_to_temp_file_if_needed(x)
|
||||||
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
|
sample_rate, data = processing_utils.audio_from_file(file_name)
|
||||||
zero_input = np.zeros_like(data, dtype="int16")
|
zero_input = np.zeros_like(data, dtype="int16")
|
||||||
# decode all of the tokens
|
# decode all of the tokens
|
||||||
token_data = []
|
token_data = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
file_obj = processing_utils.decode_base64_to_file(token["data"])
|
file_name = self.base64_to_temp_file_if_needed(token["data"])
|
||||||
_, data = processing_utils.audio_from_file(file_obj.name)
|
_, data = processing_utils.audio_from_file(file_name)
|
||||||
token_data.append(data)
|
token_data.append(data)
|
||||||
# construct the masked version
|
# construct the masked version
|
||||||
masked_inputs = []
|
masked_inputs = []
|
||||||
@ -4046,10 +4048,7 @@ class Model3D(
|
|||||||
if is_file:
|
if is_file:
|
||||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||||
else:
|
else:
|
||||||
temp_file = processing_utils.decode_base64_to_file(
|
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
|
||||||
file_data, file_path=file_name
|
|
||||||
)
|
|
||||||
temp_file_path = temp_file.name
|
|
||||||
|
|
||||||
return temp_file_path
|
return temp_file_path
|
||||||
|
|
||||||
|
@ -353,6 +353,13 @@ class TempFileManager:
|
|||||||
sha1.update(data)
|
sha1.update(data)
|
||||||
return sha1.hexdigest()
|
return sha1.hexdigest()
|
||||||
|
|
||||||
|
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||||
|
sha1 = hashlib.sha1()
|
||||||
|
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
||||||
|
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
||||||
|
sha1.update(data.encode("utf-8"))
|
||||||
|
return sha1.hexdigest()
|
||||||
|
|
||||||
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
|
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
|
||||||
file_name = Path(file_path_or_url).name
|
file_name = Path(file_path_or_url).name
|
||||||
prefix, extension = file_name, None
|
prefix, extension = file_name, None
|
||||||
@ -374,6 +381,12 @@ class TempFileManager:
|
|||||||
file_hash = self.hash_url(url)
|
file_hash = self.hash_url(url)
|
||||||
return prefix + file_hash + extension
|
return prefix + file_hash + extension
|
||||||
|
|
||||||
|
def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
|
||||||
|
extension = get_extension(base64_encoding)
|
||||||
|
extension = "." + extension if extension else ""
|
||||||
|
base64_hash = self.hash_base64(base64_encoding)
|
||||||
|
return prefix + base64_hash + extension
|
||||||
|
|
||||||
def make_temp_copy_if_needed(self, file_path: str) -> str:
|
def make_temp_copy_if_needed(self, file_path: str) -> str:
|
||||||
"""Returns a temporary file path for a copy of the given file path if it does
|
"""Returns a temporary file path for a copy of the given file path if it does
|
||||||
not already exist. Otherwise returns the path to the existing temp file."""
|
not already exist. Otherwise returns the path to the existing temp file."""
|
||||||
@ -408,6 +421,27 @@ class TempFileManager:
|
|||||||
self.temp_files.add(full_temp_file_path)
|
self.temp_files.add(full_temp_file_path)
|
||||||
return full_temp_file_path
|
return full_temp_file_path
|
||||||
|
|
||||||
|
def base64_to_temp_file_if_needed(
|
||||||
|
self, base64_encoding: str, file_name: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Converts a base64 encoding to a file and returns the path to the file if
|
||||||
|
the file doesn't already exist. Otherwise returns the path to the existing file."""
|
||||||
|
f = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
temp_dir = Path(f.name).parent
|
||||||
|
prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""
|
||||||
|
|
||||||
|
temp_file_path = self.get_temp_base64_path(base64_encoding, prefix=prefix)
|
||||||
|
f.name = str(temp_dir / temp_file_path)
|
||||||
|
full_temp_file_path = str(utils.abspath(f.name))
|
||||||
|
|
||||||
|
if not Path(full_temp_file_path).exists():
|
||||||
|
data, _ = decode_base64_to_binary(base64_encoding)
|
||||||
|
with open(full_temp_file_path, "wb") as fb:
|
||||||
|
fb.write(data)
|
||||||
|
|
||||||
|
self.temp_files.add(full_temp_file_path)
|
||||||
|
return full_temp_file_path
|
||||||
|
|
||||||
|
|
||||||
def download_tmp_copy_of_file(
|
def download_tmp_copy_of_file(
|
||||||
url_path: str, access_token: str | None = None, dir: str | None = None
|
url_path: str, access_token: str | None = None, dir: str | None = None
|
||||||
|
@ -770,9 +770,9 @@ class TestAudio:
|
|||||||
"""
|
"""
|
||||||
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
||||||
audio_input = gr.Audio()
|
audio_input = gr.Audio()
|
||||||
output = audio_input.preprocess(x_wav)
|
output1 = audio_input.preprocess(x_wav)
|
||||||
assert output[0] == 8000
|
assert output1[0] == 8000
|
||||||
assert output[1].shape == (8046,)
|
assert output1[1].shape == (8046,)
|
||||||
assert filecmp.cmp(
|
assert filecmp.cmp(
|
||||||
"test/test_files/audio_sample.wav",
|
"test/test_files/audio_sample.wav",
|
||||||
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
|
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
|
||||||
@ -796,7 +796,9 @@ class TestAudio:
|
|||||||
assert audio_input.preprocess(None) is None
|
assert audio_input.preprocess(None) is None
|
||||||
x_wav["is_example"] = True
|
x_wav["is_example"] = True
|
||||||
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
|
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
|
||||||
assert audio_input.preprocess(x_wav) is not None
|
output2 = audio_input.preprocess(x_wav)
|
||||||
|
assert output2 is not None
|
||||||
|
assert output1 != output2
|
||||||
|
|
||||||
audio_input = gr.Audio(type="filepath")
|
audio_input = gr.Audio(type="filepath")
|
||||||
assert isinstance(audio_input.preprocess(x_wav), str)
|
assert isinstance(audio_input.preprocess(x_wav), str)
|
||||||
|
@ -194,6 +194,30 @@ class TestTempFileManager:
|
|||||||
)
|
)
|
||||||
assert len(temp_file_manager.temp_files) == 2
|
assert len(temp_file_manager.temp_files) == 2
|
||||||
|
|
||||||
|
def test_base64_to_temp_file_if_needed(self):
|
||||||
|
temp_file_manager = processing_utils.TempFileManager()
|
||||||
|
|
||||||
|
base64_file_1 = media_data.BASE64_IMAGE
|
||||||
|
base64_file_2 = media_data.BASE64_AUDIO["data"]
|
||||||
|
|
||||||
|
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||||
|
try: # Delete if already exists from before this test
|
||||||
|
os.remove(f)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||||
|
assert len(temp_file_manager.temp_files) == 1
|
||||||
|
|
||||||
|
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||||
|
assert len(temp_file_manager.temp_files) == 1
|
||||||
|
|
||||||
|
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
|
||||||
|
assert len(temp_file_manager.temp_files) == 2
|
||||||
|
|
||||||
|
for file in temp_file_manager.temp_files:
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
@pytest.mark.flaky
|
@pytest.mark.flaky
|
||||||
@patch("shutil.copyfileobj")
|
@patch("shutil.copyfileobj")
|
||||||
def test_download_temp_copy_if_needed(self, mock_copy):
|
def test_download_temp_copy_if_needed(self, mock_copy):
|
||||||
|
Loading…
Reference in New Issue
Block a user