diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e0968a6b8..29bee2effb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ No changes to highlight. ## Full Changelog: * Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151) +* Caches temp files from base64 input data by giving them a deterministic path based on the contents of the data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197) * Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194) * Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194) diff --git a/gradio/components.py b/gradio/components.py index c3ff6e7e38..330c7fdc45 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -1808,13 +1808,9 @@ class Video( x.get("is_file", False), ) if is_file: - file = self.make_temp_copy_if_needed(file_name) - file_name = Path(file) + file_name = Path(self.make_temp_copy_if_needed(file_name)) else: - file = processing_utils.decode_base64_to_file( - file_data, file_path=file_name - ) - file_name = Path(file.name) + file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name)) uploaded_format = file_name.suffix.replace(".", "") modify_format = self.format is not None and uploaded_format != self.format @@ -2041,20 +2037,26 @@ class Audio( else: temp_file_path = self.make_temp_copy_if_needed(file_name) else: - temp_file_obj = processing_utils.decode_base64_to_file( -
file_data, file_path=file_name - ) - temp_file_path = temp_file_obj.name + temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name) sample_rate, data = processing_utils.audio_from_file( temp_file_path, crop_min=crop_min, crop_max=crop_max ) + # Need a unique name for the file to avoid re-using the same audio file if + # a user submits the same audio file twice, but with different crop min/max. + temp_file_path = Path(temp_file_path) + output_file_name = str( + temp_file_path.with_name( + f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}" + ) + ) + if self.type == "numpy": return sample_rate, data elif self.type == "filepath": - processing_utils.audio_to_file(sample_rate, data, temp_file_path) - return temp_file_path + processing_utils.audio_to_file(sample_rate, data, output_file_name) + return output_file_name else: raise ValueError( "Unknown type: " @@ -2075,8 +2077,8 @@ class Audio( if x.get("is_file"): sample_rate, data = processing_utils.audio_from_file(x["name"]) else: - file_obj = processing_utils.decode_base64_to_file(x["data"]) - sample_rate, data = processing_utils.audio_from_file(file_obj.name) + file_name = self.base64_to_temp_file_if_needed(x["data"]) + sample_rate, data = processing_utils.audio_from_file(file_name) leave_one_out_sets = [] tokens = [] masks = [] @@ -2117,14 +2119,14 @@ class Audio( def get_masked_inputs(self, tokens, binary_mask_matrix): # create a "zero input" vector and get sample rate x = tokens[0]["data"] - file_obj = processing_utils.decode_base64_to_file(x) - sample_rate, data = processing_utils.audio_from_file(file_obj.name) + file_name = self.base64_to_temp_file_if_needed(x) + sample_rate, data = processing_utils.audio_from_file(file_name) zero_input = np.zeros_like(data, dtype="int16") # decode all of the tokens token_data = [] for token in tokens: - file_obj = processing_utils.decode_base64_to_file(token["data"]) - _, data = processing_utils.audio_from_file(file_obj.name) + file_name = 
self.base64_to_temp_file_if_needed(token["data"]) + _, data = processing_utils.audio_from_file(file_name) token_data.append(data) # construct the masked version masked_inputs = [] @@ -4046,10 +4048,7 @@ class Model3D( if is_file: temp_file_path = self.make_temp_copy_if_needed(file_name) else: - temp_file = processing_utils.decode_base64_to_file( - file_data, file_path=file_name - ) - temp_file_path = temp_file.name + temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name) return temp_file_path diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 272dd90f17..e228d96b2f 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -353,6 +353,13 @@ class TempFileManager: sha1.update(data) return sha1.hexdigest() + def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: + sha1 = hashlib.sha1() + for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): + data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] + sha1.update(data.encode("utf-8")) + return sha1.hexdigest() + def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]: file_name = Path(file_path_or_url).name prefix, extension = file_name, None @@ -374,6 +381,12 @@ class TempFileManager: file_hash = self.hash_url(url) return prefix + file_hash + extension + def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str: + extension = get_extension(base64_encoding) + extension = "." + extension if extension else "" + base64_hash = self.hash_base64(base64_encoding) + return prefix + base64_hash + extension + def make_temp_copy_if_needed(self, file_path: str) -> str: """Returns a temporary file path for a copy of the given file path if it does not already exist. 
Otherwise returns the path to the existing temp file.""" @@ -408,6 +421,27 @@ class TempFileManager: self.temp_files.add(full_temp_file_path) return full_temp_file_path + def base64_to_temp_file_if_needed( + self, base64_encoding: str, file_name: str | None = None + ) -> str: + """Converts a base64 encoding to a file and returns the path to the file if + the file doesn't already exist. Otherwise returns the path to the existing file.""" + f = tempfile.NamedTemporaryFile(delete=False) + temp_dir = Path(f.name).parent + prefix = self.get_prefix_and_extension(file_name)[0] if file_name else "" + + temp_file_path = self.get_temp_base64_path(base64_encoding, prefix=prefix) + f.name = str(temp_dir / temp_file_path) + full_temp_file_path = str(utils.abspath(f.name)) + + if not Path(full_temp_file_path).exists(): + data, _ = decode_base64_to_binary(base64_encoding) + with open(full_temp_file_path, "wb") as fb: + fb.write(data) + + self.temp_files.add(full_temp_file_path) + return full_temp_file_path + def download_tmp_copy_of_file( url_path: str, access_token: str | None = None, dir: str | None = None diff --git a/test/test_components.py b/test/test_components.py index 0b367158ad..4f88f2b71a 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -770,9 +770,9 @@ class TestAudio: """ x_wav = deepcopy(media_data.BASE64_AUDIO) audio_input = gr.Audio() - output = audio_input.preprocess(x_wav) - assert output[0] == 8000 - assert output[1].shape == (8046,) + output1 = audio_input.preprocess(x_wav) + assert output1[0] == 8000 + assert output1[1].shape == (8046,) assert filecmp.cmp( "test/test_files/audio_sample.wav", audio_input.serialize("test/test_files/audio_sample.wav")["name"], @@ -796,7 +796,9 @@ class TestAudio: assert audio_input.preprocess(None) is None x_wav["is_example"] = True x_wav["crop_min"], x_wav["crop_max"] = 1, 4 - assert audio_input.preprocess(x_wav) is not None + output2 = audio_input.preprocess(x_wav) + assert output2 is not None + assert 
output1 != output2 audio_input = gr.Audio(type="filepath") assert isinstance(audio_input.preprocess(x_wav), str) diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 7e5ea5ed44..7ffc46b0d4 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -194,6 +194,30 @@ class TestTempFileManager: ) assert len(temp_file_manager.temp_files) == 2 + def test_base64_to_temp_file_if_needed(self): + temp_file_manager = processing_utils.TempFileManager() + + base64_file_1 = media_data.BASE64_IMAGE + base64_file_2 = media_data.BASE64_AUDIO["data"] + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + try: # Delete if already exists from before this test + os.remove(f) + except OSError: + pass + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2) + assert len(temp_file_manager.temp_files) == 2 + + for file in temp_file_manager.temp_files: + os.remove(file) + @pytest.mark.flaky @patch("shutil.copyfileobj") def test_download_temp_copy_if_needed(self, mock_copy):