import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch

import ffmpy
import numpy as np
import pytest
from gradio_client import media_data
from PIL import Image, ImageCms

from gradio import components, data_classes, processing_utils, utils
from gradio.route_utils import API_PREFIX


def _count_files(directory: Path) -> int:
    """Return the number of regular files found recursively under *directory*."""
    return sum(1 for entry in directory.glob("**/*") if entry.is_file())


class TestTempFileManagement:
    """Tests for hashing files and saving files / base64 data / URLs to the cache."""

    def test_hash_file(self):
        h1 = processing_utils.hash_file("gradio/test_data/cheetah1.jpg")
        h2 = processing_utils.hash_file("gradio/test_data/cheetah1-copy.jpg")
        h3 = processing_utils.hash_file("gradio/test_data/cheetah2.jpg")
        # cheetah1-copy.jpg is expected to hash like cheetah1.jpg; cheetah2.jpg not.
        assert h1 == h2
        assert h1 != h3

    def test_make_temp_copy_if_needed(self, gradio_temp_dir):
        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        try:  # Delete if already exists from before this test
            os.remove(f)
        except OSError:
            pass

        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 1
        assert Path(f).name == "cheetah1.jpg"

        # Saving the same source file again must not add a new cache entry.
        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 1

        # A differently named source file gets its own cache entry.
        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1-copy.jpg", cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 2
        assert Path(f).name == "cheetah1-copy.jpg"

    def test_save_b64_to_cache(self, gradio_temp_dir):
        base64_file_1 = media_data.BASE64_IMAGE
        base64_file_2 = media_data.BASE64_AUDIO["data"]

        f = processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        try:  # Delete if already exists from before this test
            os.remove(f)
        except OSError:
            pass

        f = processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 1

        # Re-saving identical base64 data must not add a new cache entry.
        f = processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 1

        f = processing_utils.save_base64_to_cache(
            base64_file_2, cache_dir=gradio_temp_dir
        )
        assert _count_files(gradio_temp_dir) == 2

    @pytest.mark.flaky
    def test_ssrf_protected_download(self, gradio_temp_dir):
        url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
        url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"

        f = processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        try:  # Delete if already exists from before this test
            os.remove(f)
        except OSError:
            pass

        f = processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        assert _count_files(gradio_temp_dir) == 1

        # Re-downloading the same URL must not add a new cache entry.
        f = processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        assert _count_files(gradio_temp_dir) == 1

        f = processing_utils.save_url_to_cache(url2, cache_dir=gradio_temp_dir)
        assert _count_files(gradio_temp_dir) == 2

    @pytest.mark.flaky
    def test_ssrf_protected_download_with_redirect(self, gradio_temp_dir):
        url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
        processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
        assert _count_files(gradio_temp_dir) == 1


class TestImagePreprocessing:
    """Tests for plot/PIL/numpy image encoding and resizing helpers."""

    def test_encode_plot_to_base64(self):
        with utils.MatplotlibBackendMananger():
            import matplotlib.pyplot as plt

            plt.plot([1, 2, 3, 4])
            output_base64 = processing_utils.encode_plot_to_base64(plt)
        # The previous `startswith("")` was vacuously true. Check for a data
        # URI instead; the image format is deliberately not pinned here.
        assert output_base64.startswith("data:image/")
        assert ";base64," in output_base64

    def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir):
        input_img = Image.open("gradio/test_data/test_image.png")
        input_img = input_img.convert("RGB")
        input_img.info = {"key1": "value1", "key2": "value2"}
        input_img.save(gradio_temp_dir / "test_test_image.png")

        file_obj = processing_utils.save_pil_to_cache(
            input_img, cache_dir=gradio_temp_dir, format="png"
        )
        output_img = Image.open(file_obj)

        assert output_img.info == input_img.info

    def test_save_pil_to_file_keeps_all_gif_frames(self, gradio_temp_dir):
        input_img = Image.open("gradio/test_data/rectangles.gif")
        file_obj = processing_utils.save_pil_to_cache(
            input_img, cache_dir=gradio_temp_dir, format="gif"
        )
        output_img = Image.open(file_obj)
        assert output_img.n_frames == input_img.n_frames == 3  # type: ignore

    def test_np_pil_encode_to_the_same(self, gradio_temp_dir):
        arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
        pil = Image.fromarray(arr)
        assert processing_utils.save_pil_to_cache(
            pil, cache_dir=gradio_temp_dir
        ) == processing_utils.save_img_array_to_cache(arr, cache_dir=gradio_temp_dir)

    def test_encode_pil_to_temp_file_metadata_color_profile(self, gradio_temp_dir):
        # Read image
        img = Image.open("gradio/test_data/test_image.png")
        img_metadata = Image.open("gradio/test_data/test_image.png")
        img_metadata.info = {"key1": "value1", "key2": "value2"}

        # Creating sRGB profile
        profile = ImageCms.createProfile("sRGB")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile.png", icc_profile=profile2.tobytes()
        )
        img_cp1 = Image.open(str(gradio_temp_dir / "img_color_profile.png"))

        # Creating XYZ profile
        profile = ImageCms.createProfile("XYZ")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile_2.png", icc_profile=profile2.tobytes()
        )
        img_cp2 = Image.open(str(gradio_temp_dir / "img_color_profile_2.png"))

        img_path = processing_utils.save_pil_to_cache(
            img, cache_dir=gradio_temp_dir, format="png"
        )
        img_metadata_path = processing_utils.save_pil_to_cache(
            img_metadata, cache_dir=gradio_temp_dir, format="png"
        )
        img_cp1_path = processing_utils.save_pil_to_cache(
            img_cp1, cache_dir=gradio_temp_dir, format="png"
        )
        img_cp2_path = processing_utils.save_pil_to_cache(
            img_cp2, cache_dir=gradio_temp_dir, format="png"
        )
        # Same pixels but different metadata / ICC profiles must map to
        # four distinct cache paths.
        assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4

    def test_resize_and_crop(self):
        img = Image.open("gradio/test_data/test_image.png")
        new_img = processing_utils.resize_and_crop(img, (20, 20))
        assert new_img.size == (20, 20)
        with pytest.raises(ValueError):
            # Passed via a ** dict, presumably so static type checkers don't
            # reject the deliberately invalid crop_type.
            processing_utils.resize_and_crop(
                **{"img": img, "size": (20, 20), "crop_type": "test"}
            )


class TestAudioPreprocessing:
    """Tests for audio file I/O and 16-bit conversion helpers."""

    def test_audio_from_file(self):
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        assert audio[0] == 22050
        assert isinstance(audio[1], np.ndarray)

    def test_audio_to_file(self, tmp_path):
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        # Write under pytest's per-test tmp dir instead of the CWD, so nothing
        # is left behind even if the assertion fails.
        output_path = str(tmp_path / "test_audio_to_file")
        processing_utils.audio_to_file(audio[0], audio[1], output_path)
        assert os.path.exists(output_path)

    def test_convert_to_16_bit_wav(self):
        # Generate a random audio sample and set the amplitude
        audio = np.random.randint(-100, 100, size=(100), dtype="int16")
        audio[0] = -32767
        audio[1] = 32766

        audio_ = audio.astype("float64")
        audio_ = processing_utils.convert_to_16_bit_wav(audio_)
        assert np.allclose(audio, audio_)
        assert audio_.dtype == "int16"

        audio_ = audio.astype("float32")
        audio_ = processing_utils.convert_to_16_bit_wav(audio_)
        assert np.allclose(audio, audio_)
        assert audio_.dtype == "int16"

        audio_ = processing_utils.convert_to_16_bit_wav(audio)
        assert np.allclose(audio, audio_)
        assert audio_.dtype == "int16"


class TestOutputPreprocessing:
    """Tests for processing_utils._convert between float dtypes."""

    # Distinct spellings of float dtypes (Python type, NumPy scalar types,
    # and dtype strings); the duplicated `float` entry was dropped.
    float_dtype_list = [
        float,
        np.double,
        np.single,
        np.float32,
        np.float64,
        "float32",
        "float64",
    ]

    def test_float_conversion_dtype(self):
        """Test any conversion from a float dtype to an other."""
        x = np.array([-1, 1])
        # Test all (input, output) combinations of dtype conversions.
        for dtype_in in TestOutputPreprocessing.float_dtype_list:
            for dtype_out in TestOutputPreprocessing.float_dtype_list:
                x = x.astype(dtype_in)
                y = processing_utils._convert(x, dtype_out)
                assert y.dtype == np.dtype(dtype_out)

    def test_subclass_conversion(self):
        """Check subclass conversion behavior"""
        x = np.array([-1, 1])
        for dtype in TestOutputPreprocessing.float_dtype_list:
            x = x.astype(dtype)
            # Converting to the abstract np.floating must keep the concrete dtype.
            y = processing_utils._convert(x, np.floating)
            assert y.dtype == x.dtype


class TestVideoProcessing:
    """Tests for video playability detection and mp4 conversion."""

    def test_video_has_playable_codecs(self, test_file_dir):
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.mp4")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.ogg")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.webm")
        )
        assert not processing_utils.video_is_playable(
            str(test_file_dir / "bad_video_sample.mp4")
        )

    def raise_ffmpy_runtime_exception(*args, **kwargs):
        # Used as a mock side_effect; accepts and ignores any arguments.
        raise ffmpy.FFRuntimeError("", "", "", "")  # type: ignore

    @pytest.mark.parametrize(
        "exception_to_raise", [raise_ffmpy_runtime_exception, KeyError(), IndexError()]
    )
    def test_video_has_playable_codecs_catches_exceptions(
        self, exception_to_raise, test_file_dir, tmp_path
    ):
        # Use pytest's tmp dir rather than NamedTemporaryFile(delete=False),
        # which leaked the copied video into the system temp directory.
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
            # When probing fails, video_is_playable is expected to return True.
            assert processing_utils.video_is_playable(str(not_playable_vid))

    def test_convert_video_to_playable_mp4(self, test_file_dir, tmp_path):
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        with patch("os.remove", wraps=os.remove) as mock_remove:
            playable_vid = processing_utils.convert_video_to_playable_mp4(
                str(not_playable_vid)
            )
        # Fail clearly (instead of a TypeError on call_args) if os.remove
        # was never called.
        mock_remove.assert_called()
        # check tempfile got deleted
        assert not Path(mock_remove.call_args[0][0]).exists()
        assert processing_utils.video_is_playable(playable_vid)

    @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
    def test_video_conversion_returns_original_video_if_fails(
        self, mock_run, test_file_dir, tmp_path
    ):
        not_playable_vid = tmp_path / "out.avi"
        shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
        playable_vid = processing_utils.convert_video_to_playable_mp4(
            str(not_playable_vid)
        )
        # If the conversion succeeded it'd be .mp4
        assert Path(playable_vid).suffix == ".avi"


def test_add_root_url():
    data = {
        "file": {
            "path": "path",
            "url": f"{API_PREFIX}/file=path",
            "meta": {"_type": "gradio.FileData"},
        },
        "file2": {
            "path": "path2",
            "url": "https://www.gradio.app",
            "meta": {"_type": "gradio.FileData"},
        },
    }
    root_url = "http://localhost:7860"
    expected = {
        "file": {
            "path": "path",
            "url": f"{root_url}{API_PREFIX}/file=path",
            "meta": {"_type": "gradio.FileData"},
        },
        "file2": {
            "path": "path2",
            "url": "https://www.gradio.app",
            "meta": {"_type": "gradio.FileData"},
        },
    }
    # Relative URLs get the root prepended; absolute URLs are untouched.
    assert processing_utils.add_root_url(data, root_url, None) == expected

    new_root_url = "https://1234.gradio.live"
    new_expected = {
        "file": {
            "path": "path",
            "url": f"{new_root_url}{API_PREFIX}/file=path",
            "meta": {"_type": "gradio.FileData"},
        },
        "file2": {
            "path": "path2",
            "url": "https://www.gradio.app",
            "meta": {"_type": "gradio.FileData"},
        },
    }
    # Passing the previous root swaps it for the new one.
    assert (
        processing_utils.add_root_url(expected, new_root_url, root_url)
        == new_expected
    )


def test_hash_url_encodes_url():
    assert processing_utils.hash_url(
        "https://www.gradio.app/image 1.jpg"
    ) == processing_utils.hash_bytes(b"https://www.gradio.app/image 1.jpg")


@pytest.mark.asyncio
async def test_json_data_not_moved_to_cache():
    data = data_classes.JsonData(
        root={
            "file": {
                "path": "path",
                "url": f"{API_PREFIX}/file=path",
                "meta": {"_type": "gradio.FileData"},
            }
        }
    )
    assert (
        processing_utils.move_files_to_cache(data, components.Number(), False)
        == data
    )
    assert processing_utils.move_files_to_cache(data, components.Number(), True) == data
    assert (
        await processing_utils.async_move_files_to_cache(
            data, components.Number(), False
        )
        == data
    )
    assert (
        await processing_utils.async_move_files_to_cache(
            data, components.Number(), True
        )
        == data
    )


@pytest.mark.flaky
def test_public_request_pass():
    # Context manager guarantees the temp dir is removed after the checks.
    with tempfile.TemporaryDirectory() as tempdir:
        file = processing_utils.ssrf_protected_download(
            "https://en.wikipedia.org/static/images/icons/wikipedia.png", tempdir
        )
        assert os.path.exists(file)
        assert os.path.getsize(file) == 13444


@pytest.mark.flaky
@pytest.mark.asyncio
async def test_async_public_request_pass():
    with tempfile.TemporaryDirectory() as tempdir:
        file = await processing_utils.async_ssrf_protected_download(
            "https://en.wikipedia.org/static/images/icons/wikipedia.png", tempdir
        )
        assert os.path.exists(file)
        assert os.path.getsize(file) == 13444


def test_private_request_fail():
    # Only the download call belongs inside pytest.raises; the temp dir is
    # created (and cleaned up) outside it.
    with tempfile.TemporaryDirectory() as tempdir:
        with pytest.raises(ValueError, match="failed validation"):
            processing_utils.ssrf_protected_download(
                "http://192.168.1.250.nip.io/image.png", tempdir
            )


@pytest.mark.asyncio
async def test_async_private_request_fail():
    with tempfile.TemporaryDirectory() as tempdir:
        with pytest.raises(ValueError, match="failed validation"):
            await processing_utils.async_ssrf_protected_download(
                "http://192.168.1.250.nip.io/image.png", tempdir
            )