Cache temp files created from base64 data (#3197)

* changes

* added workflow

* fix action

* fix action

* fix action

* changelog

* formatting

* fix

* Delete benchmark-queue.yml

* Delete benchmark_queue.py

* changelog

* lint

* fix tests

* fix tests

* fix for python 3.7

* formatting
This commit is contained in:
Abubakar Abid 2023-02-15 17:24:48 -06:00 committed by GitHub
parent 74f75f004a
commit 752ec0ef6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 26 deletions

View File

@ -55,6 +55,7 @@ No changes to highlight.
## Full Changelog:
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
* Caches temp files created from base64 input data by giving them a deterministic path based on the contents of the data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197)
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)

View File

@ -1808,13 +1808,9 @@ class Video(
x.get("is_file", False),
)
if is_file:
file = self.make_temp_copy_if_needed(file_name)
file_name = Path(file)
file_name = Path(self.make_temp_copy_if_needed(file_name))
else:
file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
file_name = Path(file.name)
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))
uploaded_format = file_name.suffix.replace(".", "")
modify_format = self.format is not None and uploaded_format != self.format
@ -2041,20 +2037,26 @@ class Audio(
else:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file_obj.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
sample_rate, data = processing_utils.audio_from_file(
temp_file_path, crop_min=crop_min, crop_max=crop_max
)
# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice, but with different crop min/max.
temp_file_path = Path(temp_file_path)
output_file_name = str(
temp_file_path.with_name(
f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}"
)
)
if self.type == "numpy":
return sample_rate, data
elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
return temp_file_path
processing_utils.audio_to_file(sample_rate, data, output_file_name)
return output_file_name
else:
raise ValueError(
"Unknown type: "
@ -2075,8 +2077,8 @@ class Audio(
if x.get("is_file"):
sample_rate, data = processing_utils.audio_from_file(x["name"])
else:
file_obj = processing_utils.decode_base64_to_file(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_name)
leave_one_out_sets = []
tokens = []
masks = []
@ -2117,14 +2119,14 @@ class Audio(
def get_masked_inputs(self, tokens, binary_mask_matrix):
# create a "zero input" vector and get sample rate
x = tokens[0]["data"]
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x)
sample_rate, data = processing_utils.audio_from_file(file_name)
zero_input = np.zeros_like(data, dtype="int16")
# decode all of the tokens
token_data = []
for token in tokens:
file_obj = processing_utils.decode_base64_to_file(token["data"])
_, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(token["data"])
_, data = processing_utils.audio_from_file(file_name)
token_data.append(data)
# construct the masked version
masked_inputs = []
@ -4046,10 +4048,7 @@ class Model3D(
if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
return temp_file_path

View File

@ -353,6 +353,13 @@ class TempFileManager:
sha1.update(data)
return sha1.hexdigest()
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
    """Return the SHA-1 hex digest of a base64 string.

    The string is UTF-8 encoded and fed to the hash in slices of
    ``chunk_num_blocks`` SHA-1 blocks, so the whole encoded payload is
    never materialised as a single bytes object.

    Parameters:
        base64_encoding: the base64 string to hash.
        chunk_num_blocks: number of SHA-1 blocks to hash per slice.
    Returns:
        The hexadecimal SHA-1 digest of the encoded string.
    """
    hasher = hashlib.sha1()
    # Slice width in characters; fixed for the whole loop.
    step = chunk_num_blocks * hasher.block_size
    for start in range(0, len(base64_encoding), step):
        piece = base64_encoding[start : start + step]
        hasher.update(piece.encode("utf-8"))
    return hasher.hexdigest()
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = Path(file_path_or_url).name
prefix, extension = file_name, None
@ -374,6 +381,12 @@ class TempFileManager:
file_hash = self.hash_url(url)
return prefix + file_hash + extension
def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
    """Build a deterministic file name for a base64 payload.

    The name is ``prefix`` followed by the hash of the encoding (from
    ``hash_base64``) and, when ``get_extension`` returns a non-empty value
    for the encoding, a dot plus that extension.

    Parameters:
        base64_encoding: the base64 string the file will be created from.
        prefix: text placed in front of the hash in the file name.
    Returns:
        The file name (not a full path) to use for the temp file.
    """
    ext = get_extension(base64_encoding)
    suffix = ""
    if ext:
        suffix = "." + ext
    return prefix + self.hash_base64(base64_encoding) + suffix
def make_temp_copy_if_needed(self, file_path: str) -> str:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
@ -408,6 +421,27 @@ class TempFileManager:
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
def base64_to_temp_file_if_needed(
    self, base64_encoding: str, file_name: str | None = None
) -> str:
    """Converts a base64 encoding to a file and returns the path to the file if
    the file doesn't already exist. Otherwise returns the path to the existing file.

    The file name comes from ``get_temp_base64_path`` (prefix plus a hash of the
    encoding), so identical input always maps to the same path inside the
    system temp directory.

    Parameters:
        base64_encoding: the base64 string to decode and write to disk.
        file_name: optional original file name; only the prefix returned by
            ``get_prefix_and_extension`` is used when building the temp name.
    Returns:
        The absolute path of the (new or already existing) temp file.
    """
    # Ask tempfile for the directory directly. Creating a NamedTemporaryFile
    # just to read its parent directory left an open handle and an empty
    # stray file behind on every call.
    temp_dir = Path(tempfile.gettempdir())
    prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""

    temp_file_name = self.get_temp_base64_path(base64_encoding, prefix=prefix)
    full_temp_file_path = str(utils.abspath(str(temp_dir / temp_file_name)))

    if not Path(full_temp_file_path).exists():
        data, _ = decode_base64_to_binary(base64_encoding)
        with open(full_temp_file_path, "wb") as fb:
            fb.write(data)
        # Only files written by this call are recorded in temp_files.
        self.temp_files.add(full_temp_file_path)

    return full_temp_file_path
def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None

View File

@ -770,9 +770,9 @@ class TestAudio:
"""
x_wav = deepcopy(media_data.BASE64_AUDIO)
audio_input = gr.Audio()
output = audio_input.preprocess(x_wav)
assert output[0] == 8000
assert output[1].shape == (8046,)
output1 = audio_input.preprocess(x_wav)
assert output1[0] == 8000
assert output1[1].shape == (8046,)
assert filecmp.cmp(
"test/test_files/audio_sample.wav",
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
@ -796,7 +796,9 @@ class TestAudio:
assert audio_input.preprocess(None) is None
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
assert audio_input.preprocess(x_wav) is not None
output2 = audio_input.preprocess(x_wav)
assert output2 is not None
assert output1 != output2
audio_input = gr.Audio(type="filepath")
assert isinstance(audio_input.preprocess(x_wav), str)

View File

@ -194,6 +194,30 @@ class TestTempFileManager:
)
assert len(temp_file_manager.temp_files) == 2
def test_base64_to_temp_file_if_needed(self):
    """Identical base64 input is cached once; different input adds a new file."""
    manager = processing_utils.TempFileManager()
    image_b64 = media_data.BASE64_IMAGE
    audio_b64 = media_data.BASE64_AUDIO["data"]

    # Resolve the deterministic path, then remove any file left over from an
    # earlier run so the next call has to write it again.
    stale_path = manager.base64_to_temp_file_if_needed(image_b64)
    try:
        os.remove(stale_path)
    except OSError:
        pass

    # First write of the image registers exactly one temp file.
    manager.base64_to_temp_file_if_needed(image_b64)
    assert len(manager.temp_files) == 1

    # Same payload again: the existing file is reused.
    manager.base64_to_temp_file_if_needed(image_b64)
    assert len(manager.temp_files) == 1

    # A different payload produces a second temp file.
    manager.base64_to_temp_file_if_needed(audio_b64)
    assert len(manager.temp_files) == 2

    # Clean up everything this test created.
    for created in manager.temp_files:
        os.remove(created)
@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):