import json
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from gradio import media_data

from gradio_client import utils


def test_encode_url_or_file_to_base64():
    output_base64 = utils.encode_url_or_file_to_base64(
        Path(__file__).parent / "../../../gradio/test_data/test_image.png"
    )
    assert output_base64 == deepcopy(media_data.BASE64_IMAGE)


def test_encode_file_to_base64():
    output_base64 = utils.encode_file_to_base64(
        Path(__file__).parent / "../../../gradio/test_data/test_image.png"
    )
    assert output_base64 == deepcopy(media_data.BASE64_IMAGE)


@pytest.mark.flaky
def test_encode_url_to_base64():
    output_base64 = utils.encode_url_to_base64(
        "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
    )
    assert output_base64 == deepcopy(media_data.BASE64_IMAGE)


def test_decode_base64_to_binary():
    binary = utils.decode_base64_to_binary(deepcopy(media_data.BASE64_IMAGE))
    assert deepcopy(media_data.BINARY_IMAGE) == binary


def test_decode_base64_to_file():
    temp_file = utils.decode_base64_to_file(deepcopy(media_data.BASE64_IMAGE))
    assert isinstance(temp_file, tempfile._TemporaryFileWrapper)


def test_download_private_file():
    url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
    hf_token = "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"  # Intentionally revealing this key for testing purposes
    file = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token)
    assert file.name.endswith(".jpg")


@pytest.mark.parametrize(
    "orig_filename, new_filename",
    [
        ("abc", "abc"),
        ("$$AAabc&3", "AAabc3"),
        ("$$AAabc&3", "AAabc3"),
        ("$$AAa..b-c&3_", "AAa..b-c3_"),
        ("$$AAa..b-c&3_", "AAa..b-c3_"),
        (
            "ゆかりです。私、こんなかわいい服は初めて着ました…。なんだかうれしくって、楽しいです。歌いたくなる気分って、初めてです。これがアイドルってことなのかもしれませんね",
            "ゆかりです私こんなかわいい服は初めて着ましたなんだかうれしくって楽しいです歌いたくなる気分って初めてですこれがアイドルってことなの",
        ),
    ],
)
def test_strip_invalid_filename_characters(orig_filename, new_filename):
    assert utils.strip_invalid_filename_characters(orig_filename) == new_filename


class AsyncMock(MagicMock):
    # Minimal awaitable mock: a MagicMock whose __call__ can be awaited.
    async def __call__(self, *args, **kwargs):
        return super(AsyncMock, self).__call__(*args, **kwargs)


@pytest.mark.asyncio
async def test_get_pred_from_ws():
    mock_ws = AsyncMock(name="ws")
    # Simulate the queue protocol: status messages followed by the final result.
    messages = [
        json.dumps({"msg": "estimation"}),
        json.dumps({"msg": "send_data"}),
        json.dumps({"msg": "process_generating"}),
        json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}),
    ]
    mock_ws.recv.side_effect = messages
    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
    hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
    output = await utils.get_pred_from_ws(mock_ws, data, hash_data)
    assert output == {"data": ["result!"]}
    mock_ws.send.assert_called_once_with(data)


@pytest.mark.asyncio
async def test_get_pred_from_ws_raises_if_queue_full():
    mock_ws = AsyncMock(name="ws")
    messages = [json.dumps({"msg": "queue_full"})]
    mock_ws.recv.side_effect = messages
    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
    hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
    with pytest.raises(utils.QueueError, match="Queue is full!"):
        await utils.get_pred_from_ws(mock_ws, data, hash_data)
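

# The test below is a hedged sketch, not part of the original module: a
# round-trip check composed only from helpers already exercised above
# (encode_file_to_base64 and decode_base64_to_binary). It assumes, as those
# tests do, that the bundled test image encodes to media_data.BASE64_IMAGE and
# that decoding that string yields media_data.BINARY_IMAGE; the test name is
# hypothetical.
def test_encode_decode_round_trip():
    encoded = utils.encode_file_to_base64(
        Path(__file__).parent / "../../../gradio/test_data/test_image.png"
    )
    assert utils.decode_base64_to_binary(encoded) == deepcopy(media_data.BINARY_IMAGE)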