mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-11 11:19:58 +08:00
* Update view api page * simplify * update * changes * changes * updated info * formatting * changes * fixes * save * moved * remove test input * tweaks * formatting * add raw * serialize * fixes * refactor * fixes * fixes * Fetch api * lower case * view api * fix tests * format * rough design * readme * api docs * examples * format * formatting * format * version * client changes * formatting * update client * more example inputs * api docs fixes * remove notebook * fix demo * demo notebook * styling on code snippet * formatting * fix audio, model3d * format * fix tests * version * cleanup * format * format * format * fixes * version * fix tests * version * format * test * format * changelog * changelog --------- Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com> Co-authored-by: aliabd <ali.si3luwa@gmail.com>
115 lines
4.1 KiB
Python
115 lines
4.1 KiB
Python
import json
|
|
import tempfile
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from requests.exceptions import HTTPError
|
|
|
|
from gradio_client import media_data, utils
|
|
|
|
|
|
def test_encode_url_or_file_to_base64():
    """A local file path passed to ``encode_url_or_file_to_base64`` yields ``media_data.BASE64_IMAGE``."""
    local_image = Path(__file__).parent / "../../../gradio/test_data/test_image.png"
    encoded = utils.encode_url_or_file_to_base64(local_image)
    expected = deepcopy(media_data.BASE64_IMAGE)
    assert encoded == expected
|
|
|
|
|
|
def test_encode_file_to_base64():
    """``encode_file_to_base64`` on the repository's test PNG matches ``media_data.BASE64_IMAGE``."""
    png_path = Path(__file__).parent / "../../../gradio/test_data/test_image.png"
    encoded = utils.encode_file_to_base64(png_path)
    expected = deepcopy(media_data.BASE64_IMAGE)
    assert encoded == expected
|
|
|
|
|
|
@pytest.mark.flaky
def test_encode_url_to_base64():
    """Encoding the test image fetched by URL matches ``media_data.BASE64_IMAGE``.

    The image is downloaded from raw.githubusercontent.com, so the test needs
    network access (and is marked flaky).
    """
    image_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
    encoded = utils.encode_url_to_base64(image_url)
    expected = deepcopy(media_data.BASE64_IMAGE)
    assert encoded == expected
|
|
|
|
|
|
def test_decode_base64_to_binary():
    """Decoding ``media_data.BASE64_IMAGE`` reproduces ``media_data.BINARY_IMAGE``."""
    encoded = deepcopy(media_data.BASE64_IMAGE)
    decoded = utils.decode_base64_to_binary(encoded)
    expected = deepcopy(media_data.BINARY_IMAGE)
    assert expected == decoded
|
|
|
|
|
|
def test_decode_base64_to_file():
    """``decode_base64_to_file`` returns a ``tempfile`` file wrapper object."""
    result = utils.decode_base64_to_file(deepcopy(media_data.BASE64_IMAGE))
    # NOTE(review): _TemporaryFileWrapper is a private tempfile class; this
    # check is tied to CPython's tempfile internals.
    wrapper_type = tempfile._TemporaryFileWrapper
    assert isinstance(result, wrapper_type)
|
|
|
|
|
|
@pytest.mark.flaky
def test_download_private_file():
    """Download a file from a Hugging Face Space using an HF token.

    The request goes to a live ``*.hf.space`` host, so the test depends on
    network access. It carries the same flaky marker as the other
    network-bound test in this module.
    """
    url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
    hf_token = "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"  # Intentionally revealing this key for testing purposes
    downloaded = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token)
    # The local copy should keep the remote file's extension.
    assert downloaded.name.endswith(".jpg")
|
|
|
|
|
|
@pytest.mark.parametrize(
    "orig_filename, new_filename",
    [
        ("abc", "abc"),
        ("$$AAabc&3", "AAabc3"),
        ("$$AAa..b-c&3_", "AAa..b-c3_"),
        (
            "ゆかりです。私、こんなかわいい服は初めて着ました…。なんだかうれしくって、楽しいです。歌いたくなる気分って、初めてです。これがアイドルってことなのかもしれませんね",
            "ゆかりです私こんなかわいい服は初めて着ましたなんだかうれしくって楽しいです歌いたくなる気分って初めてですこれがアイドルってことなの",
        ),
    ],
)
def test_strip_invalid_filename_characters(orig_filename, new_filename):
    """Check ``strip_invalid_filename_characters`` against known input/output pairs.

    The cases cover several behaviours:
    - characters such as ``$`` and ``&`` are removed;
    - letters, digits, ``.``, ``-`` and ``_`` are kept;
    - in the Japanese example, punctuation (``。``, ``、``, ``…``) is removed
      and the overlong result is truncated.
    """
    assert utils.strip_invalid_filename_characters(orig_filename) == new_filename
|
|
|
|
|
|
class AsyncMock(MagicMock):
    """``MagicMock`` whose calls must be awaited.

    Calling an instance returns a coroutine. Awaiting it performs the ordinary
    ``MagicMock`` call, which records the call and applies ``return_value`` /
    ``side_effect``, and produces that call's result.
    """

    async def __call__(self, *args, **kwargs):
        # Defer to the synchronous MagicMock machinery once the coroutine runs.
        return super().__call__(*args, **kwargs)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_pred_from_ws():
    """``get_pred_from_ws`` sends the payload and returns the ``process_completed`` output.

    The mocked websocket replays a scripted message sequence that ends with a
    ``process_completed`` message carrying the prediction.
    """
    ws = AsyncMock(name="ws")
    server_events = [
        {"msg": "estimation"},
        {"msg": "send_data"},
        {"msg": "process_generating"},
        {"msg": "process_completed", "output": {"data": ["result!"]}},
    ]
    ws.recv.side_effect = [json.dumps(event) for event in server_events]
    payload = json.dumps({"data": ["foo"], "fn_index": "foo"})
    session_payload = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})

    result = await utils.get_pred_from_ws(ws, payload, session_payload)

    assert result == {"data": ["result!"]}
    ws.send.assert_called_once_with(payload)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_pred_from_ws_raises_if_queue_full():
    """A ``queue_full`` message makes ``get_pred_from_ws`` raise ``QueueError``."""
    ws = AsyncMock(name="ws")
    ws.recv.side_effect = [json.dumps({"msg": "queue_full"})]
    payload = json.dumps({"data": ["foo"], "fn_index": "foo"})
    session_payload = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})

    with pytest.raises(utils.QueueError, match="Queue is full!"):
        await utils.get_pred_from_ws(ws, payload, session_payload)
|
|
|
|
|
|
@patch("requests.post")
def test_sleep_successful(mock_post):
    """``set_space_timeout`` completes without raising when the HTTP request succeeds.

    ``requests.post`` is patched, so no network traffic happens. The default
    ``MagicMock`` response's ``raise_for_status`` does not raise.
    """
    utils.set_space_timeout("gradio/calculator")
    # Without this check the test would also pass if set_space_timeout never
    # reached the (patched) HTTP layer at all.
    mock_post.assert_called_once()
|
|
|
|
|
|
@patch("requests.post")
def test_sleep_unsuccessful(mock_post):
    """An HTTP error from the timeout request surfaces as ``SpaceDuplicationError``."""
    # The response returned by the patched requests.post raises HTTPError
    # when its status is checked.
    mock_post.return_value.raise_for_status.side_effect = HTTPError
    with pytest.raises(utils.SpaceDuplicationError):
        utils.set_space_timeout("gradio/calculator")
|