Fix local tests (#3411)

* document embed params

* fix tests

* cleanup

* cleanup

* cleanup

* revert

* changelog
This commit is contained in:
Abubakar Abid 2023-03-07 14:30:04 -08:00 committed by GitHub
parent 2fd9b55b87
commit da9a9cfd35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 13 deletions

View File

@ -15,7 +15,7 @@ No changes to highlight.
## Testing and Infrastructure Changes:
No changes to highlight.
- Fixes tests that were failing locally but passing on CI by [@abidlabs](https://github.com/abidlabs) in [PR 3411](https://github.com/gradio-app/gradio/pull/3411)
## Breaking Changes:

View File

@ -376,9 +376,12 @@ class TempFileManager:
file_hash = self.hash_url(url)
return prefix + file_hash + extension
def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
extension = get_extension(base64_encoding)
extension = "." + extension if extension else ""
def get_temp_base64_path(
self, base64_encoding: str, prefix: str, extension: str
) -> str:
guess_extension = get_extension(base64_encoding)
if not extension and guess_extension:
extension = "." + guess_extension
base64_hash = self.hash_base64(base64_encoding)
return prefix + base64_hash + extension
@ -436,9 +439,12 @@ class TempFileManager:
the file doesn't already exist. Otherwise returns the path to the existing file."""
f = tempfile.NamedTemporaryFile(delete=False)
temp_dir = Path(f.name).parent
prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""
temp_file_path = self.get_temp_base64_path(base64_encoding, prefix=prefix)
prefix, extension = (
self.get_prefix_and_extension(file_name) if file_name else ("", "")
)
temp_file_path = self.get_temp_base64_path(
base64_encoding, prefix=prefix, extension=extension
)
f.name = str(temp_dir / temp_file_path)
full_temp_file_path = str(utils.abspath(f.name))

View File

@ -842,12 +842,11 @@ class TestAudio:
def test_serialize(self):
audio_input = gr.Audio()
assert audio_input.serialize("test/test_files/audio_sample.wav") == {
"data": media_data.BASE64_AUDIO["data"],
"is_file": False,
"orig_name": "audio_sample.wav",
"name": "test/test_files/audio_sample.wav",
}
serialized_input = audio_input.serialize("test/test_files/audio_sample.wav")
assert serialized_input["data"] == media_data.BASE64_AUDIO["data"]
assert os.path.basename(serialized_input["name"]) == "audio_sample.wav"
assert serialized_input["orig_name"] == "audio_sample.wav"
assert not serialized_input["is_file"]
def test_tokenize(self):
"""