Fix processing_utils.save_url_to_cache() to follow redirects when accessing the URL (#7322)

* Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL

* add changeset

* follow more redirects

* format

* add changeset

* add test

* validate urls

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
commit b25e95e164 (parent 200e2518e4)
Authored by Yuichiro Tachibana (Tsuchiya) on 2024-02-06 19:46:20 +00:00; committed by GitHub.
6 changed files with 26 additions and 4 deletions
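For context: httpx, unlike requests, does not follow HTTP redirects by default, so a plain GET against a redirecting URL returns the 3xx response itself instead of the file behind it. A minimal sketch of the behavior this commit fixes, using the same Hugging Face URL as the new tests (the exact redirect status code depends on the server):

import httpx

url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"

r = httpx.get(url)
print(r.status_code)  # a 3xx redirect; the body is not the image

r = httpx.get(url, follow_redirects=True)
print(r.status_code)  # 200 after the redirect chain is followed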


@@ -0,0 +1,6 @@
+---
+"gradio": patch
+"gradio_client": patch
+---
+
+fix:Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL


@@ -633,7 +633,9 @@ def download_file(
     temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
     temp_dir.mkdir(exist_ok=True, parents=True)
-    with httpx.stream("GET", url_path, headers=headers) as response:
+    with httpx.stream(
+        "GET", url_path, headers=headers, follow_redirects=True
+    ) as response:
         response.raise_for_status()
         with open(temp_dir / Path(url_path).name, "wb") as f:
             for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
@@ -666,7 +668,9 @@ def download_tmp_copy_of_file(
     directory.mkdir(exist_ok=True, parents=True)
     file_path = directory / Path(url_path).name
-    with httpx.stream("GET", url_path, headers=headers) as response:
+    with httpx.stream(
+        "GET", url_path, headers=headers, follow_redirects=True
+    ) as response:
         response.raise_for_status()
         with open(file_path, "wb") as f:
             for chunk in response.iter_raw():
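Both download helpers above apply the same streaming pattern. As a standalone sketch (fetch is a hypothetical name, not a gradio API), it boils down to:

import httpx
from pathlib import Path

def fetch(url: str, dest: Path) -> None:
    # Stream the body to disk instead of buffering it all in memory;
    # follow_redirects=True is required for hosts that serve files via 3xx hops.
    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()  # fail loudly rather than saving an error page
        with open(dest, "wb") as f:
            for chunk in response.iter_bytes():
                f.write(chunk)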


@@ -190,7 +190,9 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
     full_temp_file_path = str(abspath(temp_dir / name))
     if not Path(full_temp_file_path).exists():
-        with httpx.stream("GET", url) as r, open(full_temp_file_path, "wb") as f:
+        with httpx.stream("GET", url, follow_redirects=True) as r, open(
+            full_temp_file_path, "wb"
+        ) as f:
             for chunk in r.iter_raw():
                 f.write(chunk)
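A hedged usage sketch based on the signature above (the cache directory is an arbitrary example path):

from gradio import processing_utils

# Downloads the file behind the (possibly redirecting) URL into cache_dir and
# returns the local file path; the Path.exists() check above means repeated
# calls reuse the cached copy instead of re-downloading.
local_path = processing_utils.save_url_to_cache(
    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png",
    cache_dir="/tmp/gradio-cache",
)
print(local_path)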


@@ -592,7 +592,9 @@ def validate_url(possible_url: str) -> bool:
         head_request = httpx.head(possible_url, headers=headers, follow_redirects=True)
         # some URLs, such as AWS S3 presigned URLs, return a 405 or a 403 for HEAD requests
         if head_request.status_code in (403, 405):
-            return httpx.get(possible_url, headers=headers).is_success
+            return httpx.get(
+                possible_url, headers=headers, follow_redirects=True
+            ).is_success
         return head_request.is_success
     except Exception:
         return False
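The same HEAD-then-GET fallback strategy, restated as a self-contained sketch (url_is_reachable is a hypothetical name):

from typing import Optional

import httpx

def url_is_reachable(possible_url: str, headers: Optional[dict] = None) -> bool:
    try:
        # HEAD is cheap, but some hosts (e.g. AWS S3 presigned URLs) reject it
        # with 403/405, so fall back to a full GET in that case. Both requests
        # must follow redirects, or redirecting URLs would report as invalid.
        head = httpx.head(possible_url, headers=headers, follow_redirects=True)
        if head.status_code in (403, 405):
            return httpx.get(
                possible_url, headers=headers, follow_redirects=True
            ).is_success
        return head.is_success
    except Exception:
        return False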


@@ -101,6 +101,11 @@ class TestTempFileManagement:
         processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
         assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1
 
+    def test_save_url_to_cache_with_redirect(self, gradio_temp_dir):
+        url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
+        processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
+        assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1
+
 
 class TestImagePreprocessing:
     def test_encode_plot_to_base64(self):


@@ -198,6 +198,9 @@ class TestValidateURL:
         assert validate_url(
             "https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg"
         )
+        assert validate_url(
+            "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
+        )
 
     def test_invalid_urls(self):
         assert not (validate_url("C:/Users/"))