Cache temp files created from base64 data (#3197)

* changes

* added workflow

* fix action

* fix action

* fix action

* changelog

* formatting

* fix

* Delete benchmark-queue.yml

* Delete benchmark_queue.py

* changelog

* lint

* fix tests

* fix tests

* fix for python 3.7

* formatting
This commit is contained in:
Abubakar Abid 2023-02-15 17:24:48 -06:00 committed by GitHub
parent 74f75f004a
commit 752ec0ef6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 26 deletions

View File

@ -55,6 +55,7 @@ No changes to highlight.
## Full Changelog: ## Full Changelog:
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151) * Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
* Caches temp files created from base64 input data by giving them a deterministic path based on the contents of the data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197)
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194) * Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194) * Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)

View File

@ -1808,13 +1808,9 @@ class Video(
x.get("is_file", False), x.get("is_file", False),
) )
if is_file: if is_file:
file = self.make_temp_copy_if_needed(file_name) file_name = Path(self.make_temp_copy_if_needed(file_name))
file_name = Path(file)
else: else:
file = processing_utils.decode_base64_to_file( file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))
file_data, file_path=file_name
)
file_name = Path(file.name)
uploaded_format = file_name.suffix.replace(".", "") uploaded_format = file_name.suffix.replace(".", "")
modify_format = self.format is not None and uploaded_format != self.format modify_format = self.format is not None and uploaded_format != self.format
@ -2041,20 +2037,26 @@ class Audio(
else: else:
temp_file_path = self.make_temp_copy_if_needed(file_name) temp_file_path = self.make_temp_copy_if_needed(file_name)
else: else:
temp_file_obj = processing_utils.decode_base64_to_file( temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
file_data, file_path=file_name
)
temp_file_path = temp_file_obj.name
sample_rate, data = processing_utils.audio_from_file( sample_rate, data = processing_utils.audio_from_file(
temp_file_path, crop_min=crop_min, crop_max=crop_max temp_file_path, crop_min=crop_min, crop_max=crop_max
) )
# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice, but with different crop min/max.
temp_file_path = Path(temp_file_path)
output_file_name = str(
temp_file_path.with_name(
f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}"
)
)
if self.type == "numpy": if self.type == "numpy":
return sample_rate, data return sample_rate, data
elif self.type == "filepath": elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, temp_file_path) processing_utils.audio_to_file(sample_rate, data, output_file_name)
return temp_file_path return output_file_name
else: else:
raise ValueError( raise ValueError(
"Unknown type: " "Unknown type: "
@ -2075,8 +2077,8 @@ class Audio(
if x.get("is_file"): if x.get("is_file"):
sample_rate, data = processing_utils.audio_from_file(x["name"]) sample_rate, data = processing_utils.audio_from_file(x["name"])
else: else:
file_obj = processing_utils.decode_base64_to_file(x["data"]) file_name = self.base64_to_temp_file_if_needed(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_obj.name) sample_rate, data = processing_utils.audio_from_file(file_name)
leave_one_out_sets = [] leave_one_out_sets = []
tokens = [] tokens = []
masks = [] masks = []
@ -2117,14 +2119,14 @@ class Audio(
def get_masked_inputs(self, tokens, binary_mask_matrix): def get_masked_inputs(self, tokens, binary_mask_matrix):
# create a "zero input" vector and get sample rate # create a "zero input" vector and get sample rate
x = tokens[0]["data"] x = tokens[0]["data"]
file_obj = processing_utils.decode_base64_to_file(x) file_name = self.base64_to_temp_file_if_needed(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name) sample_rate, data = processing_utils.audio_from_file(file_name)
zero_input = np.zeros_like(data, dtype="int16") zero_input = np.zeros_like(data, dtype="int16")
# decode all of the tokens # decode all of the tokens
token_data = [] token_data = []
for token in tokens: for token in tokens:
file_obj = processing_utils.decode_base64_to_file(token["data"]) file_name = self.base64_to_temp_file_if_needed(token["data"])
_, data = processing_utils.audio_from_file(file_obj.name) _, data = processing_utils.audio_from_file(file_name)
token_data.append(data) token_data.append(data)
# construct the masked version # construct the masked version
masked_inputs = [] masked_inputs = []
@ -4046,10 +4048,7 @@ class Model3D(
if is_file: if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name) temp_file_path = self.make_temp_copy_if_needed(file_name)
else: else:
temp_file = processing_utils.decode_base64_to_file( temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)
file_data, file_path=file_name
)
temp_file_path = temp_file.name
return temp_file_path return temp_file_path

View File

@ -353,6 +353,13 @@ class TempFileManager:
sha1.update(data) sha1.update(data)
return sha1.hexdigest() return sha1.hexdigest()
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
    """Return the SHA-1 hex digest of a base64 string.

    The string is hashed as UTF-8 text (it is not base64-decoded first) and is
    fed to the hasher in slices so that no full encoded copy is held at once.

    Parameters:
        base64_encoding: The text to hash.
        chunk_num_blocks: Number of SHA-1 blocks' worth of characters to feed
            to the hasher per update call.
    Returns:
        The hexadecimal SHA-1 digest of the text.
    """
    hasher = hashlib.sha1()
    # Slice width in characters: a whole number of hash blocks.
    chunk_size = chunk_num_blocks * hasher.block_size
    for start in range(0, len(base64_encoding), chunk_size):
        piece = base64_encoding[start : start + chunk_size]
        hasher.update(piece.encode("utf-8"))
    return hasher.hexdigest()
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]: def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = Path(file_path_or_url).name file_name = Path(file_path_or_url).name
prefix, extension = file_name, None prefix, extension = file_name, None
@ -374,6 +381,12 @@ class TempFileManager:
file_hash = self.hash_url(url) file_hash = self.hash_url(url)
return prefix + file_hash + extension return prefix + file_hash + extension
def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
    """Build a deterministic temp file name for a base64 payload.

    The name is ``prefix`` followed by the hash of the payload and, when
    ``get_extension`` reports one, a dot plus that extension.

    Parameters:
        base64_encoding: The base64 payload the file will hold.
        prefix: Text placed in front of the hash in the file name.
    Returns:
        The file name (not a full path) for the payload.
    """
    suffix = get_extension(base64_encoding)
    if suffix:
        suffix = "." + suffix
    else:
        # No extension detected: the name ends with the hash.
        suffix = ""
    digest = self.hash_base64(base64_encoding)
    return "".join((prefix, digest, suffix))
def make_temp_copy_if_needed(self, file_path: str) -> str: def make_temp_copy_if_needed(self, file_path: str) -> str:
"""Returns a temporary file path for a copy of the given file path if it does """Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file.""" not already exist. Otherwise returns the path to the existing temp file."""
@ -408,6 +421,27 @@ class TempFileManager:
self.temp_files.add(full_temp_file_path) self.temp_files.add(full_temp_file_path)
return full_temp_file_path return full_temp_file_path
def base64_to_temp_file_if_needed(
    self, base64_encoding: str, file_name: str | None = None
) -> str:
    """Write base64-encoded data to a temp file unless that file already exists.

    The file name is derived from a hash of ``base64_encoding`` (see
    ``get_temp_base64_path``), so submitting the same payload again resolves
    to the same path and the data is not decoded or written a second time.

    Parameters:
        base64_encoding: The base64 payload to materialize on disk.
        file_name: Optional original file name; when given, its prefix (as
            returned by ``get_prefix_and_extension``) is prepended to the hash.
    Returns:
        The absolute path of the temp file, as a string.
    """
    # Ask tempfile for the directory directly. Creating a NamedTemporaryFile
    # just to read its parent directory left an open handle and an empty,
    # never-deleted file behind on every call.
    temp_dir = Path(tempfile.gettempdir())
    prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""
    temp_file_path = self.get_temp_base64_path(base64_encoding, prefix=prefix)
    full_temp_file_path = str(utils.abspath(str(temp_dir / temp_file_path)))

    # Only decode and write when no file for this payload is present yet.
    if not Path(full_temp_file_path).exists():
        data, _ = decode_base64_to_binary(base64_encoding)
        with open(full_temp_file_path, "wb") as fb:
            fb.write(data)

    # Record the path even when the file was already on disk, so this manager
    # tracks every temp file it has handed out.
    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path
def download_tmp_copy_of_file( def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None url_path: str, access_token: str | None = None, dir: str | None = None

View File

@ -770,9 +770,9 @@ class TestAudio:
""" """
x_wav = deepcopy(media_data.BASE64_AUDIO) x_wav = deepcopy(media_data.BASE64_AUDIO)
audio_input = gr.Audio() audio_input = gr.Audio()
output = audio_input.preprocess(x_wav) output1 = audio_input.preprocess(x_wav)
assert output[0] == 8000 assert output1[0] == 8000
assert output[1].shape == (8046,) assert output1[1].shape == (8046,)
assert filecmp.cmp( assert filecmp.cmp(
"test/test_files/audio_sample.wav", "test/test_files/audio_sample.wav",
audio_input.serialize("test/test_files/audio_sample.wav")["name"], audio_input.serialize("test/test_files/audio_sample.wav")["name"],
@ -796,7 +796,9 @@ class TestAudio:
assert audio_input.preprocess(None) is None assert audio_input.preprocess(None) is None
x_wav["is_example"] = True x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4 x_wav["crop_min"], x_wav["crop_max"] = 1, 4
assert audio_input.preprocess(x_wav) is not None output2 = audio_input.preprocess(x_wav)
assert output2 is not None
assert output1 != output2
audio_input = gr.Audio(type="filepath") audio_input = gr.Audio(type="filepath")
assert isinstance(audio_input.preprocess(x_wav), str) assert isinstance(audio_input.preprocess(x_wav), str)

View File

@ -194,6 +194,30 @@ class TestTempFileManager:
) )
assert len(temp_file_manager.temp_files) == 2 assert len(temp_file_manager.temp_files) == 2
def test_base64_to_temp_file_if_needed(self):
    """Check that identical base64 payloads share one tracked temp file."""
    manager = processing_utils.TempFileManager()
    image_b64 = media_data.BASE64_IMAGE
    audio_b64 = media_data.BASE64_AUDIO["data"]

    # The path is deterministic, so a file from an earlier run may still be
    # on disk; resolve the path once and remove it to start clean.
    stale_path = manager.base64_to_temp_file_if_needed(image_b64)
    try:
        os.remove(stale_path)
    except OSError:
        pass

    # First real write of the image payload.
    manager.base64_to_temp_file_if_needed(image_b64)
    assert len(manager.temp_files) == 1

    # Same payload again: no additional tracked file.
    manager.base64_to_temp_file_if_needed(image_b64)
    assert len(manager.temp_files) == 1

    # A different payload is tracked as a second file.
    manager.base64_to_temp_file_if_needed(audio_b64)
    assert len(manager.temp_files) == 2

    # Clean up everything this test left on disk.
    for tracked_path in manager.temp_files:
        os.remove(tracked_path)
@pytest.mark.flaky @pytest.mark.flaky
@patch("shutil.copyfileobj") @patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy): def test_download_temp_copy_if_needed(self, mock_copy):