2022-11-08 08:37:55 +08:00
"""
2022-11-16 05:23:47 +08:00
Tests for all of the components defined in components.py. Tests are divided into two types:
2022-11-08 08:37:55 +08:00
1. test_component_functions() are unit tests that check the essential functions of a component; the functions that are checked are documented in each test's docstring.
2. test_in_interface() are functional tests that check a component's functionality inside an Interface. Please do not use Interface.launch() in this file, as it slows down the tests.
"""
2022-08-23 23:31:04 +08:00
import filecmp
2022-03-30 04:23:30 +08:00
import json
import os
2023-03-09 04:24:09 +08:00
import pathlib # noqa: F401
2022-08-17 01:21:13 +08:00
import shutil
2022-03-30 04:23:30 +08:00
import tempfile
2022-04-05 17:54:17 +08:00
from copy import deepcopy
2022-03-30 04:23:30 +08:00
from difflib import SequenceMatcher
2023-03-09 04:24:09 +08:00
from pathlib import Path
2022-12-16 04:37:09 +08:00
from unittest . mock import MagicMock , patch
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
import matplotlib
2022-05-23 04:54:04 +08:00
import matplotlib . pyplot as plt
2022-03-30 04:23:30 +08:00
import numpy as np
import pandas as pd
import PIL
2022-07-14 22:28:47 +08:00
import pytest
2022-12-09 23:14:07 +08:00
import vega_datasets
2023-03-24 06:33:44 +08:00
from gradio_client import utils as client_utils
2022-08-10 04:40:31 +08:00
from scipy . io import wavfile
2022-03-30 20:34:08 +08:00
2022-03-30 04:23:30 +08:00
import gradio as gr
2022-08-17 01:21:13 +08:00
from gradio import media_data , processing_utils
2022-03-30 04:23:30 +08:00
# Turn off gradio's usage analytics for the whole test session.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
# Select the non-interactive Agg backend so figures can be built without a display.
matplotlib.use(backend="Agg")
2022-03-30 04:23:30 +08:00
2022-04-05 17:54:17 +08:00
2022-11-08 08:37:55 +08:00
class TestComponent:
    """Tests for the string-based ``gr.components.component`` factory."""

    def test_component_functions(self):
        """
        component: a shortcut name resolves to the matching template class.
        """
        resolved = gr.components.component("textarea")
        assert isinstance(resolved, gr.templates.TextArea)
2022-04-15 16:18:56 +08:00
2022-07-14 22:28:47 +08:00
def test_raise_warnings():
    """Instantiating a legacy gradio.inputs / gradio.outputs class emits a UserWarning."""
    legacy_classes = {
        "inputs": gr.inputs.Textbox,
        "outputs": gr.outputs.Label,
    }
    for namespace, legacy_cls in legacy_classes.items():
        with pytest.warns(UserWarning, match=f"Usage of gradio.{namespace}"):
            legacy_cls()
2022-11-08 08:37:55 +08:00
class TestTextbox:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, tokenize, get_config
        """
        textbox = gr.Textbox()
        greeting = "Hello World!"
        assert textbox.preprocess(greeting) == greeting

        # postprocess: strings pass through, None stays None, numbers are stringified.
        assert textbox.postprocess(greeting) == greeting
        assert textbox.postprocess(None) is None
        for raw, shown in (("Ali", "Ali"), (2, "2"), (2.14, "2.14")):
            assert textbox.postprocess(raw) == shown

        assert textbox.serialize(greeting, True) == greeting

        sentence = "Hello World! Gradio speaking."
        tokens = ["Hello", "World!", "Gradio", "speaking."]

        # Default interpretation: each token is dropped in turn.
        assert textbox.tokenize(sentence) == (
            tokens,
            [
                "World! Gradio speaking.",
                "Hello Gradio speaking.",
                "Hello World! speaking.",
                "Hello World! Gradio",
            ],
            None,
        )

        # With a replacement configured, each token is substituted instead.
        textbox.interpretation_replacement = "unknown"
        assert textbox.tokenize(sentence) == (
            tokens,
            [
                "unknown World! Gradio speaking.",
                "Hello unknown Gradio speaking.",
                "Hello World! unknown speaking.",
                "Hello World! Gradio unknown",
            ],
            None,
        )

        expected_config = {
            "lines": 1,
            "max_lines": 20,
            "placeholder": None,
            "value": "",
            "name": "textbox",
            "show_label": True,
            "type": "text",
            "label": None,
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }
        assert textbox.get_config() == expected_config

    @pytest.mark.asyncio
    async def test_in_interface_as_input(self):
        """
        Interface, process, interpret,
        """
        reverser = gr.Interface(lambda text: text[::-1], "textbox", "textbox")
        assert reverser("Hello") == "olleH"

        longest_word = gr.Interface(
            lambda sentence: max(len(word) for word in sentence.split()),
            gr.Textbox(),
            "number",
            interpretation="default",
        )
        sentence = "Return the length of the longest word in this sentence"
        scores = await longest_word.interpret([sentence])

        # Only removing "sentence" (the unique longest word) changes the output;
        # every token is followed by a zero-weight separator entry.
        expected = []
        for word in sentence.split():
            expected.append((word, 1.0 if word == "sentence" else 0.0))
            expected.append((" ", 0))
        assert scores[0]["interpretation"] == expected

    def test_in_interface_as_output(self):
        """
        Interface, process
        """
        last_char = gr.Interface(lambda text: text[-1], "textbox", gr.Textbox())
        assert last_char("Hello") == "o"

        halver = gr.Interface(lambda value: value / 2, "number", gr.Textbox())
        assert halver(10) == "5.0"

    def test_static(self):
        """
        postprocess
        """
        config = gr.Textbox("abc").get_config()
        assert config.get("value") == "abc"

    def test_override_template(self):
        """
        override template: TextArea defaults to 7 lines unless overridden.
        """
        for extra_kwargs, expected_lines in (({}, 7), ({"lines": 4}, 4)):
            config = gr.TextArea(value="abc", **extra_kwargs).get_config()
            assert config.get("value") == "abc"
            assert config.get("lines") == expected_lines

    def test_faulty_type(self):
        expected_message = '`type` must be one of "text", "password", or "email".'
        with pytest.raises(ValueError, match=expected_message):
            gr.Textbox(type="boo")

    def test_max_lines(self):
        # Single-line input types are capped at one line; plain text allows 20.
        for kind, expected in (("password", 1), ("email", 1), ("text", 20)):
            assert gr.Textbox(type=kind).get_config().get("max_lines") == expected
        assert gr.Textbox().get_config().get("max_lines") == 20
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestNumber:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, set_interpret_parameters, get_interpretation_neighbors, get_config
        """
        numeric_input = gr.Number(elem_id="num", elem_classes="first")
        assert numeric_input.preprocess(3) == 3.0
        assert numeric_input.preprocess(None) is None
        assert numeric_input.postprocess(3) == 3
        # Without a precision, values are coerced to float. `3 == 3.0` holds,
        # so an equality check alone cannot catch a missing coercion.
        assert isinstance(numeric_input.postprocess(3), float)
        assert numeric_input.postprocess(2.14) == 2.14
        assert numeric_input.postprocess(None) is None
        assert numeric_input.serialize(3, True) == 3
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
        assert numeric_input.get_interpretation_neighbors(1) == (
            [-2.0, -1.0, 0.0, 2.0, 3.0, 4.0],
            {},
        )
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
        assert numeric_input.get_interpretation_neighbors(1) == (
            [0.97, 0.98, 0.99, 1.01, 1.02, 1.03],
            {},
        )
        assert numeric_input.get_config() == {
            "value": None,
            "name": "number",
            "show_label": True,
            "label": None,
            "style": {},
            "elem_id": "num",
            "elem_classes": ["first"],
            "visible": True,
            "interactive": None,
            "root_url": None,
        }

    def test_component_functions_integer(self):
        """
        Preprocess, postprocess, serialize, set_interpret_parameters, get_interpretation_neighbors, get_config
        """
        numeric_input = gr.Number(precision=0, value=42)
        assert numeric_input.preprocess(3) == 3
        assert numeric_input.preprocess(None) is None
        assert numeric_input.postprocess(3) == 3
        # precision=0 must yield an int, not a float that merely equals one.
        assert isinstance(numeric_input.postprocess(3), int)
        assert numeric_input.postprocess(2.85) == 3
        assert numeric_input.postprocess(None) is None
        assert numeric_input.serialize(3, True) == 3
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
        assert numeric_input.get_interpretation_neighbors(1) == (
            [-2.0, -1.0, 0.0, 2.0, 3.0, 4.0],
            {},
        )
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
        assert numeric_input.get_interpretation_neighbors(100) == (
            [97.0, 98.0, 99.0, 101.0, 102.0, 103.0],
            {},
        )
        # A 1% step around 1 is fractional, which precision=0 cannot represent.
        # The message is checked via `match=`: an assert placed inside the
        # `with` block after the raising call would never execute.
        with pytest.raises(ValueError, match="Cannot generate valid set of neighbors"):
            numeric_input.get_interpretation_neighbors(1)
        # A non-integer absolute delta is likewise rejected for precision=0.
        numeric_input.set_interpret_parameters(
            steps=3, delta=1.24, delta_type="absolute"
        )
        with pytest.raises(ValueError, match="Cannot generate valid set of neighbors"):
            numeric_input.get_interpretation_neighbors(4)
        assert numeric_input.get_config() == {
            "value": 42,
            "name": "number",
            "show_label": True,
            "label": None,
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }

    def test_component_functions_precision(self):
        """
        Preprocess, postprocess: values are rounded to `precision` decimal places.
        """
        numeric_input = gr.Number(precision=2, value=42.3428)
        assert numeric_input.preprocess(3.231241) == 3.23
        assert numeric_input.preprocess(None) is None
        assert numeric_input.postprocess(-42.1241) == -42.12
        assert numeric_input.postprocess(5.6784) == 5.68
        assert numeric_input.postprocess(2.1421) == 2.14
        assert numeric_input.postprocess(None) is None

    @pytest.mark.asyncio
    async def test_in_interface_as_input(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x**2, "number", "textbox")
        assert iface(2) == "4.0"
        iface = gr.Interface(
            lambda x: x**2, "number", "number", interpretation="default"
        )
        scores = (await iface.interpret([2]))[0]["interpretation"]
        assert scores == [
            (1.94, -0.23640000000000017),
            (1.96, -0.15840000000000032),
            (1.98, -0.07960000000000012),
            (2, None),
            (2.02, 0.08040000000000003),
            (2.04, 0.16159999999999997),
            (2.06, 0.24359999999999982),
        ]

    @pytest.mark.asyncio
    async def test_precision_0_in_interface(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x**2, gr.Number(precision=0), "textbox")
        assert iface(2) == "4"
        iface = gr.Interface(
            lambda x: x**2, "number", gr.Number(precision=0), interpretation="default"
        )
        # Output gets rounded to 4 for all input so no change
        scores = (await iface.interpret([2]))[0]["interpretation"]
        assert scores == [
            (1.94, 0.0),
            (1.96, 0.0),
            (1.98, 0.0),
            (2, None),
            (2.02, 0.0),
            (2.04, 0.0),
            (2.06, 0.0),
        ]

    @pytest.mark.asyncio
    async def test_in_interface_as_output(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
        assert iface(2) == 4.0
        iface = gr.Interface(
            lambda x: x**2, "number", "number", interpretation="default"
        )
        scores = (await iface.interpret([2]))[0]["interpretation"]
        assert scores == [
            (1.94, -0.23640000000000017),
            (1.96, -0.15840000000000032),
            (1.98, -0.07960000000000012),
            (2, None),
            (2.02, 0.08040000000000003),
            (2.04, 0.16159999999999997),
            (2.06, 0.24359999999999982),
        ]

    def test_static(self):
        """
        postprocess
        """
        component = gr.Number()
        assert component.get_config().get("value") is None
        component = gr.Number(3)
        assert component.get_config().get("value") == 3.0
2022-05-21 08:53:27 +08:00
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestSlider:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        slider_input = gr.Slider()
        assert slider_input.preprocess(3.0) == 3.0
        assert slider_input.postprocess(3) == 3
        # With default bounds, an empty value postprocesses to 0.
        assert slider_input.postprocess(None) == 0
        assert slider_input.serialize(3, True) == 3

        slider_input = gr.Slider(10, 20, value=15, step=1, label="Slide Your Input")
        assert slider_input.get_config() == {
            "minimum": 10,
            "maximum": 20,
            "step": 1,
            "value": 15,
            "name": "slider",
            "show_label": True,
            "label": "Slide Your Input",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }

    @pytest.mark.asyncio
    async def test_in_interface(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x**2, "slider", "textbox")
        assert iface(2) == "4"
        iface = gr.Interface(
            lambda x: x**2, "slider", "number", interpretation="default"
        )
        scores = (await iface.interpret([2]))[0]["interpretation"]
        assert scores == [
            -4.0,
            200.08163265306123,
            812.3265306122449,
            1832.7346938775513,
            3261.3061224489797,
            5098.040816326531,
            7342.938775510205,
            9996.0,
        ]

    def test_static(self):
        """
        postprocess
        """
        component = gr.Slider(0, 100, 5)
        assert component.get_config().get("value") == 5
        component = gr.Slider(0, 100, None)
        assert component.get_config().get("value") == 0

    @patch("gradio.Slider.get_random_value", return_value=7)
    def test_slider_get_random_value_on_load(self, mock_get_random_value):
        """A randomized slider takes its initial value from get_random_value and
        registers a load event that re-draws it."""
        slider = gr.Slider(minimum=-5, maximum=10, randomize=True)
        assert slider.value == 7
        assert slider.load_event_to_attach[0]() == 7
        assert slider.load_event_to_attach[1] is None

    @patch("random.randint", return_value=3)
    def test_slider_rounds_when_using_default_randomizer(self, mock_randint):
        slider = gr.Slider(minimum=0, maximum=1, randomize=True, step=0.1)
        # If get_random_value didn't round, this test would fail
        # because 0.30000000000000004 != 0.3
        assert slider.get_random_value() == 0.3
        mock_randint.assert_called()
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestCheckbox:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        bool_input = gr.Checkbox()
        # Booleans must round-trip unchanged. Identity checks ensure that some
        # other truthy value (e.g. the string "True") would not pass.
        assert bool_input.preprocess(True) is True
        assert bool_input.preprocess(False) is False
        assert bool_input.postprocess(True) is True
        assert bool_input.postprocess(False) is False
        assert bool_input.serialize(True, True) is True
        bool_input = gr.Checkbox(value=True, label="Check Your Input")
        assert bool_input.get_config() == {
            "value": True,
            "name": "checkbox",
            "show_label": True,
            "label": "Check Your Input",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }

    @pytest.mark.asyncio
    async def test_in_interface(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
        assert iface(True) == 1
        iface = gr.Interface(
            lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
        )
        scores = (await iface.interpret([False]))[0]["interpretation"]
        assert scores == (None, 1.0)
        scores = (await iface.interpret([True]))[0]["interpretation"]
        assert scores == (-1.0, None)
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestCheckboxGroup:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        picked = ["a", "c"]
        group = gr.CheckboxGroup(["a", "b", "c"])
        assert group.preprocess(picked) == picked
        assert group.postprocess(picked) == picked
        assert group.serialize(picked, True) == picked

        labelled_group = gr.CheckboxGroup(
            value=picked,
            choices=["a", "b", "c"],
            label="Check Your Inputs",
        )
        expected_config = {
            "choices": ["a", "b", "c"],
            "value": ["a", "c"],
            "name": "checkboxgroup",
            "show_label": True,
            "label": "Check Your Inputs",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }
        assert labelled_group.get_config() == expected_config

        with pytest.raises(ValueError):
            gr.CheckboxGroup(["a"], type="unknown")

        # A bare string value is wrapped into a one-element list.
        wrapped = gr.CheckboxGroup(choices=["a", "b"], value="c")
        assert wrapped.get_config()["value"] == ["c"]
        assert wrapped.postprocess("a") == ["a"]

    def test_in_interface(self):
        """
        Interface, process
        """
        joiner = gr.Interface(
            lambda selected: "|".join(selected),
            gr.CheckboxGroup(["a", "b", "c"]),
            "textbox",
        )
        assert joiner(["a", "c"]) == "a|c"
        assert joiner([]) == ""
        # Constructing with type="index" must not raise.
        _ = gr.CheckboxGroup(["a", "b", "c"], type="index")
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestRadio:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        radio = gr.Radio(["a", "b", "c"])
        assert radio.preprocess("c") == "c"
        assert radio.postprocess("a") == "a"
        assert radio.serialize("a", True) == "a"

        # The legacy `default` keyword is passed, yet the config value stays None.
        labelled_radio = gr.Radio(
            choices=["a", "b", "c"], default="a", label="Pick Your One Input"
        )
        expected_config = {
            "choices": ["a", "b", "c"],
            "value": None,
            "name": "radio",
            "show_label": True,
            "label": "Pick Your One Input",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
        }
        assert labelled_radio.get_config() == expected_config

        with pytest.raises(ValueError):
            gr.Radio(["a", "b"], type="unknown")

    @pytest.mark.asyncio
    async def test_in_interface(self):
        """
        Interface, process, interpret
        """
        by_value = gr.Interface(
            lambda choice: 2 * choice, gr.Radio(["a", "b", "c"]), "textbox"
        )
        assert by_value("c") == "cc"

        # With type="index" the function receives the choice's position.
        by_index = gr.Interface(
            lambda choice: 2 * choice,
            gr.Radio(["a", "b", "c"], type="index"),
            "number",
            interpretation="default",
        )
        assert by_index("c") == 4
        interpretation = (await by_index.interpret(["b"]))[0]["interpretation"]
        assert interpretation == [-2.0, None, 2.0]
2022-03-30 04:23:30 +08:00
2023-01-05 08:13:46 +08:00
class TestDropdown:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        # Single-select dropdown: values are plain strings.
        dropdown_input = gr.Dropdown(["a", "b", "c"])
        assert dropdown_input.preprocess("a") == "a"
        assert dropdown_input.postprocess("a") == "a"
        # Multiselect dropdown: values are lists. The `multiselect=True` flag
        # belongs on this component, not on the single-select one above.
        dropdown_input_multiselect = gr.Dropdown(["a", "b", "c"], multiselect=True)
        assert dropdown_input_multiselect.preprocess(["a", "c"]) == ["a", "c"]
        assert dropdown_input_multiselect.postprocess(["a", "c"]) == ["a", "c"]
        assert dropdown_input_multiselect.serialize(["a", "c"], True) == ["a", "c"]
        dropdown_input_multiselect = gr.Dropdown(
            value=["a", "c"],
            choices=["a", "b", "c"],
            label="Select Your Inputs",
            multiselect=True,
            max_choices=2,
        )
        assert dropdown_input_multiselect.get_config() == {
            "allow_custom_value": False,
            "choices": ["a", "b", "c"],
            "value": ["a", "c"],
            "name": "dropdown",
            "show_label": True,
            "label": "Select Your Inputs",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
            "multiselect": True,
            "max_choices": 2,
        }
        with pytest.raises(ValueError):
            gr.Dropdown(["a"], type="unknown")
        dropdown = gr.Dropdown(choices=["a", "b"], value="c")
        assert dropdown.get_config()["value"] == "c"
        assert dropdown.postprocess("a") == "a"

    def test_in_interface(self):
        """
        Interface, process
        """
        # Exercise a Dropdown (previously this test used a CheckboxGroup).
        dropdown_input = gr.Dropdown(["a", "b", "c"], multiselect=True)
        iface = gr.Interface(lambda x: "|".join(x), dropdown_input, "textbox")
        assert iface(["a", "c"]) == "a|c"
        assert iface([]) == ""
        # Constructing with type="index" must not raise.
        _ = gr.Dropdown(["a", "b", "c"], type="index")
2022-11-08 08:37:55 +08:00
class TestImage:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config, _segment_by_slic
        type: pil, filepath, numpy
        """
        img = deepcopy(media_data.BASE64_IMAGE)
        image_input = gr.Image()
        assert image_input.preprocess(img).shape == (68, 61, 3)
        image_input = gr.Image(shape=(25, 25), image_mode="L")
        assert image_input.preprocess(img).shape == (25, 25)
        image_input = gr.Image(shape=(30, 10), type="pil")
        assert image_input.preprocess(img).size == (30, 10)
        assert image_input.postprocess("test/test_files/bus.png") == img
        assert image_input.serialize("test/test_files/bus.png") == img
        # type="filepath" writes a temp file and tracks it on the component.
        image_input = gr.Image(type="filepath")
        image_temp_filepath = image_input.preprocess(img)
        assert image_temp_filepath in image_input.temp_files
        image_input = gr.Image(
            source="upload", tool="editor", type="pil", label="Upload Your Image"
        )
        assert image_input.get_config() == {
            "brush_radius": None,
            "image_mode": "RGB",
            "shape": None,
            "source": "upload",
            "tool": "editor",
            "name": "image",
            "streaming": False,
            "show_label": True,
            "label": "Upload Your Image",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
            "mirror_webcam": True,
        }
        assert image_input.preprocess(None) is None
        image_input = gr.Image(invert_colors=True)
        assert image_input.preprocess(img) is not None
        file_image = gr.Image(type="filepath")
        assert isinstance(file_image.preprocess(img), str)
        with pytest.raises(ValueError):
            gr.Image(type="unknown")
        image_input.shape = (30, 10)
        assert image_input._segment_by_slic(img) is not None

        # Output functionalities
        y_img = gr.processing_utils.decode_base64_to_image(
            deepcopy(media_data.BASE64_IMAGE)
        )
        image_output = gr.Image()
        # Both PIL images and numpy arrays are encoded as base64 PNG data URLs.
        # Check a real prefix: an empty prefix would make startswith() vacuous.
        assert image_output.postprocess(y_img).startswith("data:image/png;base64,")
        assert image_output.postprocess(np.array(y_img)).startswith(
            "data:image/png;base64,"
        )
        with pytest.raises(ValueError):
            image_output.postprocess([1, 2, 3])
        image_output = gr.Image(type="numpy")
        assert image_output.postprocess(y_img).startswith("data:image/png;base64,")

    @pytest.mark.flaky
    def test_serialize_url(self):
        img = "https://gradio.app/assets/img/header-image.jpg"
        expected = client_utils.encode_url_or_file_to_base64(img)
        assert gr.Image().serialize(img) == expected

    def test_in_interface_as_input(self):
        """
        Interface, process
        type: filepath
        interpretation: default (construction only)
        """
        img = "test/test_files/bus.png"
        image_input = gr.Image()
        iface = gr.Interface(
            lambda x: PIL.Image.open(x).rotate(90, expand=True),
            gr.Image(shape=(30, 10), type="filepath"),
            "image",
        )
        output = iface(img)
        assert PIL.Image.open(output).size == (10, 30)
        # Smoke check only: building an Image-input Interface with default
        # interpretation must not raise. Interpretation scores are not asserted.
        iface = gr.Interface(
            lambda x: np.sum(x), image_input, "number", interpretation="default"
        )

    def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def generate_noise(width, height):
            return np.random.randint(0, 256, (width, height, 3))

        iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
        assert iface(10, 20).endswith(".png")

    def test_static(self):
        """
        postprocess
        """
        component = gr.Image("test/test_files/bus.png")
        assert component.get_config().get("value") == media_data.BASE64_IMAGE
        component = gr.Image(None)
        assert component.get_config().get("value") is None
2022-05-21 08:53:27 +08:00
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestPlot:
    @pytest.mark.asyncio
    async def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def plot(num):
            figure = plt.figure()
            points = range(num)
            plt.plot(points, points)
            return figure

        iface = gr.Interface(plot, "slider", "plot")
        response = await iface.process_api(fn_index=0, inputs=[10], state={})
        first_output = response["data"][0]
        # Matplotlib figures are shipped to the frontend as base64 PNG images.
        assert first_output["type"] == "matplotlib"
        assert first_output["plot"].startswith("data:image/png;base64")

    def test_static(self):
        """
        postprocess
        """
        figure = plt.figure()
        plt.plot([1, 2, 3], [1, 2, 3])
        assert gr.Plot(figure).get_config().get("value") is not None
        assert gr.Plot(None).get_config().get("value") is None

    def test_postprocess_altair(self):
        import altair as alt
        from vega_datasets import data

        base = alt.Chart(data.cars())
        chart = base.mark_point().encode(
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Origin",
        )
        # Altair charts are serialized to their Vega-Lite JSON spec.
        payload = gr.Plot().postprocess(chart)
        assert isinstance(payload["plot"], str)
        assert payload["plot"] == chart.to_json()
2022-05-23 04:54:04 +08:00
2022-11-08 08:37:55 +08:00
class TestAudio:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config, deserialize
        type: filepath, numpy, file
        """
        x_wav = deepcopy(media_data.BASE64_AUDIO)
        audio_input = gr.Audio()
        numpy_output = audio_input.preprocess(x_wav)
        assert numpy_output[0] == 8000
        assert numpy_output[1].shape == (8046,)

        x_wav["is_file"] = True
        audio_input = gr.Audio(type="filepath")
        filepath_output = audio_input.preprocess(x_wav)
        assert Path(filepath_output).name == "audio_sample-0-100.wav"

        assert filecmp.cmp(
            "test/test_files/audio_sample.wav",
            audio_input.serialize("test/test_files/audio_sample.wav")["name"],
        )

        audio_input = gr.Audio(label="Upload Your Audio")
        assert audio_input.get_config() == {
            "source": "upload",
            "name": "audio",
            "streaming": False,
            "show_label": True,
            "label": "Upload Your Audio",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
        }
        assert audio_input.preprocess(None) is None

        x_wav["is_example"] = True
        x_wav["crop_min"], x_wav["crop_max"] = 1, 4
        cropped_output = audio_input.preprocess(x_wav)
        assert cropped_output is not None
        # Cropping to 1%-4% of the clip must keep the sample rate but return
        # fewer samples than the uncropped numpy output above. Comparing
        # against the filepath string would be trivially unequal.
        assert cropped_output[0] == numpy_output[0]
        assert len(cropped_output[1]) < len(numpy_output[1])

        audio_input = gr.Audio(type="filepath")
        assert isinstance(audio_input.preprocess(x_wav), str)
        with pytest.raises(ValueError):
            gr.Audio(type="unknown")

        # Output functionalities
        y_audio = client_utils.decode_base64_to_file(
            deepcopy(media_data.BASE64_AUDIO)["data"]
        )
        audio_output = gr.Audio(type="filepath")
        assert filecmp.cmp(y_audio.name, audio_output.postprocess(y_audio.name)["name"])
        assert audio_output.get_config() == {
            "name": "audio",
            "streaming": False,
            "show_label": True,
            "label": None,
            "source": "upload",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
        }
        assert audio_output.deserialize(
            {
                "name": None,
                "data": deepcopy(media_data.BASE64_AUDIO)["data"],
                "is_file": False,
            }
        ).endswith(".wav")

        # Postprocessing the same file twice must give identical results.
        output1 = audio_output.postprocess(y_audio.name)
        output2 = audio_output.postprocess(y_audio.name)
        assert output1 == output2

    def test_serialize(self):
        audio_input = gr.Audio()
        serialized_input = audio_input.serialize("test/test_files/audio_sample.wav")
        assert serialized_input["data"] == media_data.BASE64_AUDIO["data"]
        assert os.path.basename(serialized_input["name"]) == "audio_sample.wav"
        assert serialized_input["orig_name"] == "audio_sample.wav"
        assert not serialized_input["is_file"]

    def test_tokenize(self):
        """
        Tokenize, get_masked_inputs
        """
        x_wav = deepcopy(media_data.BASE64_AUDIO)
        audio_input = gr.Audio()
        tokens, _, _ = audio_input.tokenize(x_wav)
        assert len(tokens) == audio_input.interpretation_segments
        x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
        similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
        assert similarity > 0.9

    def test_in_interface(self):
        def reverse_audio(audio):
            sr, data = audio
            return (sr, np.flipud(data))

        iface = gr.Interface(reverse_audio, "audio", "audio")
        reversed_file = iface("test/test_files/audio_sample.wav")
        reversed_reversed_file = iface(reversed_file)
        reversed_reversed_data = client_utils.encode_url_or_file_to_base64(
            reversed_reversed_file
        )
        # Reversing twice should approximately restore the original audio.
        similarity = SequenceMatcher(
            a=reversed_reversed_data, b=media_data.BASE64_AUDIO["data"]
        ).ratio()
        assert similarity > 0.99

    def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def generate_noise(duration):
            return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)

        iface = gr.Interface(generate_noise, "slider", "audio")
        assert iface(100).endswith(".wav")

    def test_audio_preprocess_can_be_read_by_scipy(self):
        x_wav = deepcopy(media_data.BASE64_MICROPHONE)
        audio_input = gr.Audio(type="filepath")
        output = audio_input.preprocess(x_wav)
        wavfile.read(output)
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestFile:
    def test_component_functions(self):
        """
        Preprocess, serialize, postprocess, get_config, value
        """
        x_file = deepcopy(media_data.BASE64_FILE)
        file_input = gr.File()
        output = file_input.preprocess(x_file)
        assert isinstance(output, tempfile._TemporaryFileWrapper)
        serialized = file_input.serialize("test/test_files/sample_file.pdf")
        assert filecmp.cmp(
            serialized["name"],
            "test/test_files/sample_file.pdf",
        )
        assert serialized["orig_name"] == "sample_file.pdf"
        assert output.orig_name == "test/test_files/sample_file.pdf"

        # Preprocessing the same on-disk file twice must reuse the same copy.
        x_file["is_file"] = True
        input1 = file_input.preprocess(x_file)
        input2 = file_input.preprocess(x_file)
        assert input1.name == input2.name
        assert Path(input1.name).name == "sample_file.pdf"

        file_input = gr.File(label="Upload Your File")
        assert file_input.get_config() == {
            "file_count": "single",
            "file_types": None,
            "name": "file",
            "show_label": True,
            "label": "Upload Your File",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
            "selectable": False,
        }
        assert file_input.preprocess(None) is None
        x_file["is_example"] = True
        assert file_input.preprocess(x_file) is not None

        zero_size_file = {"name": "document.txt", "size": 0, "data": "data:"}
        temp_file = file_input.preprocess(zero_size_file)
        assert os.stat(temp_file.name).st_size == 0

        file_input = gr.File(type="binary")
        output = file_input.preprocess(x_file)
        assert isinstance(output, bytes)

        output1 = file_input.postprocess("test/test_files/sample_file.pdf")
        output2 = file_input.postprocess("test/test_files/sample_file.pdf")
        assert output1 == output2

    def test_file_type_must_be_list(self):
        with pytest.raises(
            ValueError, match="Parameter file_types must be a list. Received str"
        ):
            gr.File(file_types=".json")

    def test_in_interface_as_input(self):
        """
        Interface, process
        """
        x_file = media_data.BASE64_FILE["name"]

        def get_size_of_file(file_obj):
            return os.path.getsize(file_obj.name)

        iface = gr.Interface(get_size_of_file, "file", "number")
        assert iface(x_file) == 10558

    def test_as_component_as_output(self, tmp_path):
        """
        Interface, process
        """

        def write_file(content):
            # Write into pytest's per-test temp dir rather than the CWD, so
            # running the suite does not leave a stray test.txt behind.
            path = tmp_path / "test.txt"
            with open(path, "w") as f:
                f.write(content)
            return str(path)

        iface = gr.Interface(write_file, "text", "file")
        assert iface("hello world").endswith(".txt")
2022-03-30 04:23:30 +08:00
2022-12-16 04:37:09 +08:00
class TestUploadButton:
    def test_component_functions(self):
        """
        preprocess
        """
        x_file = deepcopy(media_data.BASE64_FILE)
        upload_input = gr.UploadButton()
        # `uploaded` rather than `input`, to avoid shadowing the builtin.
        uploaded = upload_input.preprocess(x_file)
        assert isinstance(uploaded, tempfile._TemporaryFileWrapper)

        # Preprocessing the same on-disk file twice must reuse the same copy.
        x_file["is_file"] = True
        input1 = upload_input.preprocess(x_file)
        input2 = upload_input.preprocess(x_file)
        assert input1.name == input2.name

    def test_raises_if_file_types_is_not_list(self):
        with pytest.raises(
            ValueError, match="Parameter file_types must be a list. Received int"
        ):
            gr.UploadButton(file_types=2)
2022-12-16 04:37:09 +08:00
2022-11-08 08:37:55 +08:00
class TestDataframe:
    @staticmethod
    def _expected_config(**overrides):
        """Return the get_config() dict of a default Dataframe, with overrides."""
        config = {
            "headers": [1, 2, 3],
            "datatype": ["str", "str", "str"],
            "row_count": (1, "dynamic"),
            "col_count": (3, "dynamic"),
            "value": {
                "data": [
                    ["", "", ""],
                ],
                "headers": [1, 2, 3],
            },
            "name": "dataframe",
            "show_label": True,
            "label": None,
            "max_rows": 20,
            "max_cols": None,
            "overflow_row_behaviour": "paginate",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
            "wrap": False,
        }
        config.update(overrides)
        return config

    def test_component_functions(self):
        """
        Preprocess, postprocess, get_config
        """
        x_data = {
            "data": [["Tim", 12, False], ["Jan", 24, True]],
            "headers": ["Name", "Age", "Member"],
        }

        with_headers = gr.Dataframe(headers=["Name", "Age", "Member"])
        frame = with_headers.preprocess(x_data)
        assert frame["Age"][1] == 24
        assert not frame["Member"][0]
        assert with_headers.postprocess(x_data) == x_data

        labelled = gr.Dataframe(
            headers=["Name", "Age", "Member"], label="Dataframe Input"
        )
        assert labelled.get_config() == self._expected_config(
            headers=["Name", "Age", "Member"],
            value={
                "data": [
                    ["", "", ""],
                ],
                "headers": ["Name", "Age", "Member"],
            },
            label="Dataframe Input",
        )

        plain = gr.Dataframe()
        frame = plain.preprocess(x_data)
        assert frame["Age"][1] == 24
        with pytest.raises(ValueError):
            gr.Dataframe(type="unknown")

        default_output = gr.Dataframe()
        assert default_output.get_config() == self._expected_config()

    def test_postprocess(self):
        """
        postprocess
        """
        default_output = gr.Dataframe()
        cases = [
            ([], {"data": [[]], "headers": []}),
            (np.zeros((2, 2)), {"data": [[0, 0], [0, 0]], "headers": [1, 2]}),
            ([[1, 3, 5]], {"data": [[1, 3, 5]], "headers": [1, 2, 3]}),
            (
                pd.DataFrame(
                    [[2, True], [3, True], [4, False]], columns=["num", "prime"]
                ),
                {
                    "headers": ["num", "prime"],
                    "data": [[2, True], [3, True], [4, False]],
                },
            ),
        ]
        for value, expected in cases:
            assert default_output.postprocess(value) == expected

        with pytest.raises(ValueError):
            gr.Dataframe(type="unknown")

        # When the headers don't match the data: fewer columns truncate the
        # headers, extra columns get their positional index as a header.
        mismatched_cases = [
            (
                [[2, True], [3, True]],
                {"headers": ["one", "two"], "data": [[2, True], [3, True]]},
            ),
            (
                [[2, True, "ab", 4], [3, True, "cd", 5]],
                {
                    "headers": ["one", "two", "three", 4],
                    "data": [[2, True, "ab", 4], [3, True, "cd", 5]],
                },
            ),
        ]
        for value, expected in mismatched_cases:
            component = gr.Dataframe(headers=["one", "two", "three"])
            assert component.postprocess(value) == expected

    def test_dataframe_postprocess_all_types(self):
        df = pd.DataFrame(
            {
                "date_1": pd.date_range("2021-01-01", periods=2),
                "date_2": pd.date_range("2022-02-15", periods=2).strftime(
                    "%B %d, %Y, %r"
                ),
                "number": np.array([0.2233, 0.57281]),
                "number_2": np.array([84, 23]).astype(np.int64),
                "bool": [True, False],
                "markdown": ["# Hello", "# Goodbye"],
            }
        )
        component = gr.Dataframe(
            datatype=["date", "date", "number", "number", "bool", "markdown"]
        )
        first_row = [
            pd.Timestamp("2021-01-01 00:00:00"),
            "February 15, 2022, 12:00:00 AM",
            0.2233,
            84,
            True,
            "<h1>Hello</h1>\n",
        ]
        second_row = [
            pd.Timestamp("2021-01-02 00:00:00"),
            "February 16, 2022, 12:00:00 AM",
            0.57281,
            23,
            False,
            "<h1>Goodbye</h1>\n",
        ]
        assert component.postprocess(df) == {
            "headers": list(df.columns),
            "data": [first_row, second_row],
        }

    def test_dataframe_postprocess_only_dates(self):
        df = pd.DataFrame(
            {
                "date_1": pd.date_range("2021-01-01", periods=2),
                "date_2": pd.date_range("2022-02-15", periods=2),
            }
        )
        component = gr.Dataframe(datatype=["date", "date"])
        expected_rows = [
            [pd.Timestamp("2021-01-01 00:00:00"), pd.Timestamp("2022-02-15 00:00:00")],
            [pd.Timestamp("2021-01-02 00:00:00"), pd.Timestamp("2022-02-16 00:00:00")],
        ]
        assert component.postprocess(df) == {
            "headers": list(df.columns),
            "data": expected_rows,
        }
2022-03-30 04:23:30 +08:00
2022-10-13 04:48:53 +08:00
class TestDataset:
    # Component names shared by every Dataset built in these tests.
    COMPONENTS = ("number", "textbox", "image", "html", "markdown")

    @staticmethod
    def _samples(bus):
        """Return a fresh copy of the two sample rows, with `bus` as the image."""
        return [
            [5, "hello", bus, "<b>Bold</b>", "**Bold**"],
            [15, "hi", bus, "<i>Italics</i>", "*Italics*"],
        ]

    def test_preprocessing(self):
        test_file_dir = Path(__file__).parent / "test_files"
        bus = str(Path(test_file_dir, "bus.png").resolve())

        by_value = gr.Dataset(
            components=list(self.COMPONENTS),
            samples=self._samples(bus),
        )
        assert by_value.preprocess(1) == [
            15,
            "hi",
            bus,
            "<i>Italics</i>",
            "<p><em>Italics</em></p>\n",
        ]

        by_index = gr.Dataset(
            components=list(self.COMPONENTS),
            samples=self._samples(bus),
            type="index",
        )
        assert by_index.preprocess(1) == 1

    def test_postprocessing(self):
        test_file_dir = Path(Path(__file__).parent, "test_files")
        bus = Path(test_file_dir, "bus.png")
        dataset = gr.Dataset(components=list(self.COMPONENTS), type="index")
        result = dataset.postprocess(samples=self._samples(bus))
        assert result == {
            "samples": self._samples(bus),
            "__type__": "update",
        }
2022-11-08 08:37:55 +08:00
class TestVideo:
    def test_component_functions(self):
        """
        Preprocess, serialize, deserialize, get_config
        """
        x_video = deepcopy(media_data.BASE64_VIDEO)
        video_input = gr.Video()
        output1 = video_input.preprocess(x_video)
        assert isinstance(output1, str)
        output2 = video_input.preprocess(x_video)
        assert output1 == output2

        video_input = gr.Video(label="Upload Your Video")
        assert video_input.get_config() == {
            "source": "upload",
            "name": "video",
            "show_label": True,
            "label": "Upload Your Video",
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
            "mirror_webcam": True,
            "include_audio": True,
        }
        assert video_input.preprocess(None) is None
        x_video["is_example"] = True
        assert video_input.preprocess(x_video) is not None
        video_input = gr.Video(format="avi")
        output_video = video_input.preprocess(x_video)
        assert output_video[-3:] == "avi"
        assert "flip" not in output_video

        assert filecmp.cmp(
            video_input.serialize(x_video["name"])["name"], x_video["name"]
        )

        # Output functionalities
        y_vid_path = "test/test_files/video_sample.mp4"
        video_output = gr.Video()
        output1 = video_output.postprocess(y_vid_path)["name"]
        assert output1.endswith("mp4")
        output2 = video_output.postprocess(y_vid_path)["name"]
        assert output1 == output2
        assert video_output.postprocess(y_vid_path)["orig_name"] == "video_sample.mp4"

        assert video_output.deserialize(
            {
                "name": None,
                "data": deepcopy(media_data.BASE64_VIDEO)["data"],
                "is_file": False,
            }
        ).endswith(".mp4")

    def test_in_interface(self):
        """
        Interface, process
        """
        x_video = media_data.BASE64_VIDEO["name"]
        iface = gr.Interface(lambda x: x, "video", "playable_video")
        assert iface(x_video).endswith(".mp4")

    def test_with_waveform(self):
        """
        Interface, process
        """
        x_audio = media_data.BASE64_AUDIO["name"]
        iface = gr.Interface(lambda x: gr.make_waveform(x), "audio", "video")
        assert iface(x_audio).endswith(".mp4")

    def test_video_postprocess_converts_to_playable_format(self, tmp_path):
        test_file_dir = Path(Path(__file__).parent, "test_files")
        cases = [
            # This file has a playable container but not a playable codec.
            ("bad_video_sample.mp4", "bad_video.mp4"),
            # This file has a playable codec but not a playable container.
            ("playable_but_bad_container.mkv", "playable_but_bad_container.mkv"),
        ]
        for sample_name, copy_name in cases:
            bad_vid = str(test_file_dir / sample_name)
            assert not processing_utils.video_is_playable(bad_vid)
            # Copy into pytest's per-test temp dir so the copy and the
            # converted output are cleaned up automatically.
            not_playable_vid = tmp_path / copy_name
            shutil.copy(bad_vid, not_playable_vid)
            _ = gr.Video().postprocess(str(not_playable_vid))
            # The converted .mp4 is expected next to the input file.
            full_path_to_output = not_playable_vid.with_suffix(".mp4")
            assert processing_utils.video_is_playable(str(full_path_to_output))

    # Patching Path.exists ensures the cached temp video file is not reused,
    # so ffmpeg is invoked for every preprocess call below.
    @patch("pathlib.Path.exists", MagicMock(return_value=False))
    @patch("gradio.components.FFmpeg")
    def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
        x_video = deepcopy(media_data.BASE64_VIDEO)
        video_input = gr.Video(source="webcam")
        _ = video_input.preprocess(x_video)

        # Dict mapping filename to FFmpeg options
        output_params = mock_ffmpeg.call_args_list[0][1]["outputs"]
        assert "hflip" in list(output_params.values())[0]
        assert "flip" in list(output_params.keys())[0]

        mock_ffmpeg.reset_mock()
        _ = gr.Video(
            source="webcam", mirror_webcam=False, include_audio=True
        ).preprocess(x_video)
        mock_ffmpeg.assert_not_called()

        mock_ffmpeg.reset_mock()
        _ = gr.Video(source="upload", format="mp4", include_audio=True).preprocess(
            x_video
        )
        mock_ffmpeg.assert_not_called()

        mock_ffmpeg.reset_mock()
        output_file = gr.Video(
            source="webcam", mirror_webcam=True, format="avi"
        ).preprocess(x_video)
        output_params = mock_ffmpeg.call_args_list[0][1]["outputs"]
        assert "hflip" in list(output_params.values())[0]
        assert "flip" in list(output_params.keys())[0]
        assert ".avi" in list(output_params.keys())[0]
        assert ".avi" in output_file

        mock_ffmpeg.reset_mock()
        output_file = gr.Video(
            source="webcam", mirror_webcam=False, format="avi", include_audio=False
        ).preprocess(x_video)
        output_params = mock_ffmpeg.call_args_list[0][1]["outputs"]
        assert list(output_params.values())[0] == ["-an"]
        assert "flip" not in list(output_params.keys())[0]
        assert ".avi" in list(output_params.keys())[0]
        assert ".avi" in output_file
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestTimeseries:
    @staticmethod
    def _expected_config(**overrides):
        """Return the get_config() dict shared by these Timeseries, with overrides."""
        config = {
            "name": "timeseries",
            "show_label": True,
            "colors": None,
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
        }
        config.update(overrides)
        return config

    def test_component_functions(self):
        """
        Preprocess, postprocess, get_config,
        """
        multi_series = gr.Timeseries(x="time", y=["retail", "food", "other"])
        row = [1] + [2] * len(multi_series.y)
        x_timeseries = {
            "data": [row] * 4,
            "headers": [multi_series.x] + multi_series.y,
        }
        frame = multi_series.preprocess(x_timeseries)
        assert isinstance(frame, pd.core.frame.DataFrame)

        single_series = gr.Timeseries(
            x="time", y="retail", label="Upload Your Timeseries"
        )
        assert single_series.get_config() == self._expected_config(
            x="time", y=["retail"], label="Upload Your Timeseries"
        )
        assert single_series.preprocess(None) is None
        x_timeseries["range"] = (0, 1)
        assert single_series.preprocess(x_timeseries) is not None

        # Output functionalities
        unconfigured = gr.Timeseries(label="Disease")
        assert unconfigured.get_config() == self._expected_config(
            x=None, y=None, label="Disease"
        )
        df = pd.DataFrame(
            {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
        )
        expected = {
            "headers": ["Name", "Age"],
            "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
        }
        assert unconfigured.postprocess(df) == expected

        with_y = gr.Timeseries(y="Age", label="Disease")
        assert with_y.postprocess(df) == expected
2022-04-07 06:40:28 +08:00
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestNames:
    # This test ensures that `components.get_component_instance()` works
    # correctly when instantiating from components: no two direct subclasses of
    # Component may share a lowercased class name (presumably because lookups
    # by name are case-insensitive).
    def test_no_duplicate_uncased_names(self):
        subclasses = gr.components.Component.__subclasses__()
        unique_subclasses_uncased = {s.__name__.lower() for s in subclasses}
        assert len(subclasses) == len(unique_subclasses_uncased)
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestLabel :
2022-04-06 04:35:04 +08:00
def test_component_functions ( self ) :
2022-04-05 17:54:17 +08:00
"""
2022-08-23 23:31:04 +08:00
Process , postprocess , deserialize
2022-04-05 17:54:17 +08:00
"""
2022-03-30 04:23:30 +08:00
y = " happy "
label_output = gr . Label ( )
label = label_output . postprocess ( y )
2022-11-08 08:37:55 +08:00
assert label == { " label " : " happy " }
assert json . load ( open ( label_output . deserialize ( label ) ) ) == label
2022-08-23 23:31:04 +08:00
y = { 3 : 0.7 , 1 : 0.2 , 0 : 0.1 }
2022-03-30 04:23:30 +08:00
label = label_output . postprocess ( y )
2022-11-08 08:37:55 +08:00
assert label == {
" label " : 3 ,
" confidences " : [
{ " label " : 3 , " confidence " : 0.7 } ,
{ " label " : 1 , " confidence " : 0.2 } ,
{ " label " : 0 , " confidence " : 0.1 } ,
] ,
}
2022-03-30 04:23:30 +08:00
label_output = gr . Label ( num_top_classes = 2 )
label = label_output . postprocess ( y )
2022-08-26 06:26:27 +08:00
2022-11-08 08:37:55 +08:00
assert label == {
" label " : 3 ,
" confidences " : [
{ " label " : 3 , " confidence " : 0.7 } ,
{ " label " : 1 , " confidence " : 0.2 } ,
] ,
}
with pytest . raises ( ValueError ) :
2022-03-30 04:23:30 +08:00
label_output . postprocess ( [ 1 , 2 , 3 ] )
2023-03-09 04:24:09 +08:00
test_file_dir = Path ( Path ( __file__ ) . parent , " test_files " )
path = str ( Path ( test_file_dir , " test_label_json.json " ) )
2022-12-16 04:37:09 +08:00
label_dict = label_output . postprocess ( path )
assert label_dict [ " label " ] == " web site "
2022-11-08 08:37:55 +08:00
assert label_output . get_config ( ) == {
" name " : " label " ,
" show_label " : True ,
" num_top_classes " : 2 ,
" value " : None ,
" label " : None ,
" style " : { } ,
" elem_id " : None ,
2023-03-16 05:01:53 +08:00
" elem_classes " : None ,
2022-11-08 08:37:55 +08:00
" visible " : True ,
" interactive " : None ,
" root_url " : None ,
2022-12-02 03:30:11 +08:00
" color " : None ,
2023-03-14 08:12:41 +08:00
" selectable " : False ,
2022-11-08 08:37:55 +08:00
}
2022-04-05 17:54:17 +08:00
2022-12-02 03:30:11 +08:00
def test_color_argument(self):
    """
    ``color`` is reported by get_config and forwarded by ``Label.update``.
    """
    assert gr.Label(value=-10, color="red").get_config()["color"] == "red"

    # Explicit colors, named or hex, are passed through unchanged.
    explicit_cases = [
        ("bad", "brown"),
        ("bad", "#ff9966"),
        ({"bad": 0.9, "good": 0.09, "so-so": 0.01}, "green"),
    ]
    for value, color in explicit_cases:
        assert gr.Label.update(value=value, color=color)["color"] == color

    # Omitting ``color`` leaves it unset ...
    unset = gr.Label.update(value={"bad": 0.8, "good": 0.18, "so-so": 0.02})
    assert unset["color"] is None
    # ... whereas passing color=None explicitly yields "transparent".
    cleared = gr.Label.update(
        value={"bad": 0.8, "good": 0.18, "so-so": 0.02}, color=None
    )
    assert cleared["color"] == "transparent"
2022-12-14 07:01:27 +08:00
def test_in_interface(self):
    """
    Interface, process: a dict of confidences returned by the wrapped function
    is written out as label JSON.
    """
    x_img = "test/test_files/bus.png"

    def rgb_distribution(img):
        # Per-channel mean over both spatial axes, normalised to sum to 1 and
        # rounded to two decimals.
        channel_means = np.mean(img, axis=(0, 1))
        shares = np.round(channel_means / np.sum(channel_means), decimals=2)
        return {"red": shares[0], "green": shares[1], "blue": shares[2]}

    iface = gr.Interface(rgb_distribution, "image", "label")
    expected = {
        "label": "red",
        "confidences": [
            {"label": "red", "confidence": 0.44},
            {"label": "green", "confidence": 0.28},
            {"label": "blue", "confidence": 0.28},
        ],
    }
    with open(iface(x_img)) as fp:
        assert json.load(fp) == expected
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestHighlightedText:
    def test_postprocess(self):
        """
        postprocess: (token, label) lists pass through unchanged, and
        {"text", "entities"} dicts are split into (token, label) pieces.
        """
        text = "Wolfgang lives in Berlin"
        expected = [
            ("", None),
            ("Wolfgang", "PER"),
            (" lives in ", None),
            ("Berlin", "LOC"),
            ("", None),
        ]

        def from_entities(component, txt, entities):
            return component.postprocess({"text": txt, "entities": entities})

        default = gr.HighlightedText()
        assert default.postprocess(expected) == expected
        assert (
            from_entities(
                default,
                text,
                [
                    {"entity": "PER", "start": 0, "end": 8},
                    {"entity": "LOC", "start": 18, "end": 24},
                ],
            )
            == expected
        )

        # An entity split over adjacent spans is merged only when
        # combine_adjacent=True. After a merge, empty entries are stripped
        # except the leading one.
        split_entities = [
            {"entity": "PER", "start": 0, "end": 4},
            {"entity": "PER", "start": 4, "end": 8},
            {"entity": "LOC", "start": 18, "end": 24},
        ]
        merged = [
            ("", None),
            ("Wolfgang", "PER"),
            (" lives in ", None),
            ("Berlin", "LOC"),
        ]
        unmerged = from_entities(default, text, split_entities)
        assert unmerged != expected
        assert unmerged != merged
        combining = gr.HighlightedText(combine_adjacent=True)
        assert from_entities(combining, text, split_entities) == merged

        plain = gr.HighlightedText()
        # Entity order in the input does not affect the result.
        assert (
            from_entities(
                plain,
                text,
                [
                    {"entity": "LOC", "start": 18, "end": 24},
                    {"entity": "PER", "start": 0, "end": 8},
                ],
            )
            == expected
        )

        # Without entities the whole text is a single unlabelled piece.
        assert from_entities(plain, "I live there", []) == [("I live there", None)]

        # An entity spanning the whole text is framed by empty pieces.
        whole = from_entities(
            plain, "Wolfgang", [{"entity": "PER", "start": 0, "end": 8}]
        )
        assert whole == [("", None), ("Wolfgang", "PER"), ("", None)]

    def test_component_functions(self):
        """
        get_config
        """
        ht_output = gr.HighlightedText(color_map={"pos": "green", "neg": "red"})
        expected_config = {
            "color_map": {"pos": "green", "neg": "red"},
            "name": "highlightedtext",
            "show_label": True,
            "label": None,
            "show_legend": False,
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "interactive": None,
            "root_url": None,
            "selectable": False,
        }
        assert ht_output.get_config() == expected_config

    def test_in_interface(self):
        """
        Interface, process: runs of vowels / non-vowels become labelled pieces.
        """

        def highlight_vowels(sentence):
            runs = []
            for letter in sentence:
                kind = "vowel" if letter in "aeiou" else "non"
                if runs and runs[-1][1] == kind:
                    runs[-1][0] += letter
                else:
                    runs.append([letter, kind])
            if not runs:
                # An empty sentence yields one empty, unlabelled piece.
                return [("", None)]
            return [(chunk, kind) for chunk, kind in runs]

        iface = gr.Interface(highlight_vowels, "text", "highlight")
        with open(iface("Helloooo")) as fp:
            assert json.load(fp) == [
                ["H", "non"],
                ["e", "vowel"],
                ["ll", "non"],
                ["oooo", "vowel"],
            ]
2022-11-30 04:26:21 +08:00
class TestChatbot:
    def test_component_functions(self):
        """
        Postprocess, preprocess, get_config
        """
        chatbot = gr.Chatbot()
        # Markdown in message text is rendered to HTML by postprocess.
        assert chatbot.postprocess([["You are **cool**", "so are *you*"]]) == [
            ["You are <strong>cool</strong>", "so are <em>you</em>"]
        ]

        # File messages are (filepath,) or (filepath, alt_text) tuples.
        multimodal_msg = [
            [("test/test_files/video_sample.mp4",), "cool video"],
            [("test/test_files/audio_sample.wav",), "cool audio"],
            [("test/test_files/bus.png", "A bus"), "cool pic"],
        ]
        processed_multimodal_msg = [
            [
                {
                    "name": "video_sample.mp4",
                    "mime_type": "video/mp4",
                    "alt_text": None,
                    "data": None,
                    "is_file": True,
                },
                "cool video",
            ],
            [
                {
                    "name": "audio_sample.wav",
                    "mime_type": "audio/wav",
                    "alt_text": None,
                    "data": None,
                    "is_file": True,
                },
                "cool audio",
            ],
            [
                {
                    "name": "bus.png",
                    "mime_type": "image/png",
                    "alt_text": "A bus",
                    "data": None,
                    "is_file": True,
                },
                "cool pic",
            ],
        ]

        postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg)
        # The "name" of each file dict may carry a directory prefix, so only
        # basenames are compared.
        postprocessed_multimodal_msg_base_names = []
        for x, y in postprocessed_multimodal_msg:
            if isinstance(x, dict):
                x["name"] = os.path.basename(x["name"])
            postprocessed_multimodal_msg_base_names.append([x, y])
        assert postprocessed_multimodal_msg_base_names == processed_multimodal_msg

        preprocessed_multimodal_msg = chatbot.preprocess(processed_multimodal_msg)
        multimodal_msg_base_names = []
        for x, y in multimodal_msg:
            # Strip the directory from the file path but keep the optional alt
            # text. Non-tuple entries pass through unchanged instead of reusing
            # a value computed for a previous message.
            if isinstance(x, tuple):
                x = (os.path.basename(x[0]),) + x[1:2]
            multimodal_msg_base_names.append([x, y])
        assert multimodal_msg_base_names == preprocessed_multimodal_msg

        assert chatbot.get_config() == {
            "value": [],
            "label": None,
            "show_label": True,
            "interactive": None,
            "name": "chatbot",
            "visible": True,
            "elem_id": None,
            "elem_classes": None,
            "style": {},
            "root_url": None,
            "selectable": False,
        }
2022-11-08 08:37:55 +08:00
class TestJSON:
    def test_component_functions(self):
        """
        Postprocess, get_config
        """
        js_output = gr.JSON()
        # The previous `assert value, message` form passed for any truthy
        # result. Compare the postprocessed value with the parsed JSON instead.
        assert js_output.postprocess('{"a":1, "b": 2}') == {"a": 1, "b": 2}
        assert js_output.get_config() == {
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": None,
            "show_label": True,
            "label": None,
            "name": "json",
            "interactive": None,
            "root_url": None,
        }

    @pytest.mark.asyncio
    async def test_in_interface(self):
        """
        Interface, process: a DataFrame input is reduced to a JSON dict.
        """

        def get_avg_age_per_gender(data):
            # Average only the "age" column. DataFrame.mean() over the whole
            # filtered frame also includes the string "gender" column, which
            # pandas >= 2.0 no longer drops silently (numeric_only=False).
            return {
                gender: int(data.loc[data["gender"] == gender, "age"].mean())
                for gender in ("M", "F", "O")
            }

        iface = gr.Interface(
            get_avg_age_per_gender,
            gr.Dataframe(headers=["gender", "age"]),
            "json",
        )
        y_data = [
            ["M", 30],
            ["F", 20],
            ["M", 40],
            ["O", 20],
            ["F", 30],
        ]
        assert (
            await iface.process_api(
                0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
            )
        )["data"][0] == {
            "M": 35,
            "F": 25,
            "O": 20,
        }
2022-03-30 04:23:30 +08:00
2022-11-08 08:37:55 +08:00
class TestHTML:
    def test_component_functions(self):
        """
        get_config reflects the constructor's value and label.
        """
        html_component = gr.components.HTML("#Welcome onboard", label="HTML Input")
        expected_config = {
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "value": "#Welcome onboard",
            "show_label": True,
            "label": "HTML Input",
            "name": "html",
            "interactive": None,
            "root_url": None,
        }
        assert html_component.get_config() == expected_config

    def test_in_interface(self):
        """
        Interface, process: the returned HTML string is passed through as-is.
        """

        def bold_text(text):
            return f"<strong>{text}</strong>"

        iface = gr.Interface(bold_text, "text", "html")
        assert iface("test") == "<strong>test</strong>"
2022-03-30 04:23:30 +08:00
2022-11-30 03:26:17 +08:00
class TestMarkdown :
# Tests for gr.Markdown: the component exposes its markdown source rendered as HTML.
def test_component_functions ( self ) :
# `$x$` is rendered as inline math (a <span class="math inline"> wrapping an
# <svg>), and the straight apostrophe becomes a typographic one. Only a
# prefix of the rendered HTML is compared, via startswith().
markdown_component = gr . Markdown ( " # Let ' s learn about $x$ " , label = " Markdown " )
2022-12-02 03:30:11 +08:00
assert markdown_component . get_config ( ) [ " value " ] . startswith (
2023-02-23 06:34:38 +08:00
""" <h1>Let’ s learn about <span class= " math inline " ><span style= \' font-size: 0px \' >x</span><svg xmlns:xlink= " http://www.w3.org/1999/xlink " height= " 0.9678125em " viewBox= " 0 0 11.6 19.35625 " xmlns= " http://www.w3.org/2000/svg " version= " 1.1 " > \n \n <defs> \n <style type= " text/css " >* { stroke-linejoin: round; stroke-linecap: butt}</style> \n </defs> \n <g id= " figure_1 " > \n <g id= " patch_1 " > """
2022-12-02 03:30:11 +08:00
)
2022-11-30 03:26:17 +08:00
2022-12-14 07:01:27 +08:00
def test_in_interface ( self ) :
2022-11-30 03:26:17 +08:00
"""
Interface , process
"""
# An identity function routed to a "markdown" output: the link is rendered
# as an <a> with target="_blank", and the HTML ends with a newline.
iface = gr . Interface ( lambda x : x , " text " , " markdown " )
input_data = " Here ' s an [image](https://gradio.app/images/gradio_logo.png) "
output_data = iface ( input_data )
assert (
output_data
2023-02-21 19:33:07 +08:00
== """ <p>Here’ s an <a href= " https://gradio.app/images/gradio_logo.png " target= " _blank " >image</a></p> \n """
2022-11-30 03:26:17 +08:00
)
2022-11-08 08:37:55 +08:00
class TestModel3D:
    def test_component_functions(self):
        """
        get_config, and postprocess being repeatable for the same file.
        """
        component = gr.components.Model3D(None, label="Model")
        expected_config = {
            "clearColor": [0, 0, 0, 0],
            "value": None,
            "label": "Model",
            "show_label": True,
            "interactive": None,
            "root_url": None,
            "name": "model3d",
            "visible": True,
            "elem_id": None,
            "elem_classes": None,
            "style": {},
        }
        assert component.get_config() == expected_config

        # Postprocessing the same file twice gives identical payloads.
        gltf_file = "test/test_files/Box.gltf"
        first, second = [component.postprocess(gltf_file) for _ in range(2)]
        assert first == second

    def test_in_interface(self):
        """
        Interface, process: a model file round-trips to a .gltf output path.
        """
        iface = gr.Interface(lambda x: x, "model3d", "model3d")
        assert iface("test/test_files/Box.gltf").endswith(".gltf")
2022-11-08 08:37:55 +08:00
class TestColorPicker:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        color_picker_input = gr.ColorPicker()
        # Hex color strings pass through preprocess/postprocess/serialize unchanged.
        assert color_picker_input.preprocess("#000000") == "#000000"
        assert color_picker_input.postprocess("#000000") == "#000000"
        assert color_picker_input.postprocess(None) is None
        assert color_picker_input.postprocess("#FFFFFF") == "#FFFFFF"
        assert color_picker_input.serialize("#000000", True) == "#000000"
        # (A leftover `interpretation_replacement = "unknown"` assignment was
        # removed: nothing asserted on it and it is not part of the config.)
        assert color_picker_input.get_config() == {
            "value": None,
            "show_label": True,
            "label": None,
            "style": {},
            "elem_id": None,
            "elem_classes": None,
            "visible": True,
            "interactive": None,
            "root_url": None,
            "name": "colorpicker",
        }

    def test_in_interface_as_input(self):
        """
        Interface, process (string shortcut for the output component)
        """
        iface = gr.Interface(lambda x: x, "colorpicker", "colorpicker")
        assert iface("#000000") == "#000000"

    def test_in_interface_as_output(self):
        """
        Interface, process (component instance as the output)
        """
        iface = gr.Interface(lambda x: x, "colorpicker", gr.ColorPicker())
        assert iface("#000000") == "#000000"

    def test_static(self):
        """
        get_config reports the initial value passed to the constructor.
        """
        component = gr.ColorPicker("#000000")
        assert component.get_config().get("value") == "#000000"
2022-10-13 00:37:58 +08:00
class TestCarousel:
    """gr.Carousel is deprecated: creating it raises DeprecationWarning in each context below."""

    def test_deprecation(self):
        bus_image = Path(Path(__file__).parent, "test_files", "bus.png")
        with pytest.raises(DeprecationWarning):
            gr.Carousel([bus_image])

    def test_deprecation_in_interface(self):
        with pytest.raises(DeprecationWarning):
            gr.Interface(lambda x: ["lion.jpg"], "textbox", "carousel")

    def test_deprecation_in_blocks(self):
        with pytest.raises(DeprecationWarning), gr.Blocks():
            gr.Textbox()
            gr.Carousel()
2022-10-07 13:08:30 +08:00
class TestGallery:
    @patch("uuid.uuid4", return_value="my-uuid")
    def test_gallery(self, mock_uuid):
        """
        deserialize writes the images under a uuid-named path; serialize
        restores the same base64 payloads.
        """
        gallery = gr.Gallery()
        test_file_dir = Path(Path(__file__).parent, "test_files")
        data = [
            client_utils.encode_file_to_base64(Path(test_file_dir, file_name))
            for file_name in ("bus.png", "cheetah1.jpg")
        ]

        with tempfile.TemporaryDirectory() as tmpdir:
            path = gallery.deserialize(data, tmpdir)
            assert path.endswith("my-uuid")
            # Each serialized entry holds its base64 payload at [0]["data"].
            restored = [entry[0]["data"] for entry in gallery.serialize(path)]
            assert sorted(data) == sorted(restored)
2022-08-30 00:53:05 +08:00
class TestState:
    def test_as_component(self):
        state = gr.State(value=5)
        # preprocess hands these values back unchanged.
        for value in (10, "abc"):
            assert state.preprocess(value) == value
        assert state.stateful

    @pytest.mark.asyncio
    async def test_in_interface(self):
        def test(x, y="def"):
            combined = x + y
            return (combined, combined)

        io = gr.Interface(test, ["text", "state"], ["text", "state"])
        first = await io.call_function(0, ["abc"])
        assert first["prediction"][0] == "abcdef"
        # Feed the previous output back in as the state argument.
        second = await io.call_function(0, ["abc", first["prediction"][0]])
        assert second["prediction"][0] == "abcabcdef"

    @staticmethod
    def _counter_demo(state_cls):
        """Build a Blocks demo whose button increments a state value by one."""
        with gr.Blocks() as demo:
            score = state_cls()
            btn = gr.Button()
            btn.click(lambda x: x + 1, score, score)
        return demo

    @staticmethod
    async def _assert_counts_up(demo):
        """Call the click handler twice, feeding each result back in."""
        result = await demo.call_function(0, [0])
        assert result["prediction"] == 1
        result = await demo.call_function(0, [result["prediction"]])
        assert result["prediction"] == 2

    @pytest.mark.asyncio
    async def test_in_blocks(self):
        await self._assert_counts_up(self._counter_demo(gr.State))

    @pytest.mark.asyncio
    async def test_variable_for_backwards_compatibility(self):
        await self._assert_counts_up(self._counter_demo(gr.Variable))
2022-09-01 03:46:43 +08:00
def test_dataframe_as_example_converts_dataframes():
    """Dataframe.as_example turns DataFrames and arrays into row-major nested lists."""
    df_comp = gr.Dataframe()
    frame = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})
    assert df_comp.as_example(frame) == [[1, 5], [2, 6], [3, 7], [4, 8]]

    array = np.array([[1, 2], [3, 4.0]])
    assert df_comp.as_example(array) == [[1.0, 2.0], [3.0, 4.0]]
2022-11-02 05:36:38 +08:00
@pytest.mark.parametrize("component", [gr.Model3D, gr.File, gr.Audio])
def test_as_example_returns_file_basename(component):
    """File-backed components show only the basename (or "" for None) as an example."""
    instance = component()
    cases = [
        ("/home/freddy/sources/example.ext", "example.ext"),
        (None, ""),
    ]
    for example, expected in cases:
        assert instance.as_example(example) == expected
2022-09-01 03:46:43 +08:00
@patch("gradio.components.IOComponent.as_example")
@patch("gradio.components.Image.as_example")
@patch("gradio.components.File.as_example")
@patch("gradio.components.Dataframe.as_example")
@patch("gradio.components.Model3D.as_example")
def test_dataset_calls_as_example(*mocks):
    """gr.Dataset renders each sample through its component's as_example."""
    components = [gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D(), gr.Textbox()]
    sample = [
        pd.DataFrame({"a": np.array([1, 2, 3])}),
        "foo.png",
        "bar.jpeg",
        "duck.obj",
        "hello",
    ]
    gr.Dataset(components=components, samples=[sample])
    assert all(mock.called for mock in mocks)
2022-12-09 23:14:07 +08:00
# Module-level sample DataFrames shared by the plot tests.
cars = vega_datasets.data.cars()
stocks = vega_datasets.data.stocks()
barley = vega_datasets.data.barley()
simple = pd.DataFrame(
    {
        "a": list("ABCDEFGHI"),
        "b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
    }
)
2022-12-09 23:14:07 +08:00
class TestScatterPlot:
    @staticmethod
    def _plot_config(postprocessed):
        """Decode the chart spec that postprocess stores as a JSON string under "plot"."""
        return json.loads(postprocessed["plot"])

    # bokeh is stubbed so the reported bokeh_version is deterministic.
    @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
    def test_get_config(self):
        expected_config = {
            "caption": None,
            "elem_id": None,
            "elem_classes": None,
            "interactive": None,
            "label": None,
            "name": "plot",
            "root_url": None,
            "show_label": True,
            "style": {},
            "value": None,
            "visible": True,
            "bokeh_version": "3.0.3",
        }
        assert gr.ScatterPlot().get_config() == expected_config

    def test_no_color(self):
        plot = gr.ScatterPlot(
            x="Horsepower",
            y="Miles_per_Gallon",
            tooltip="Name",
            title="Car Data",
            x_title="Horse",
        )
        output = plot.postprocess(cars)
        assert sorted(output) == ["chart", "plot", "type"]
        config = self._plot_config(output)
        x_encoding = config["encoding"]["x"]
        assert x_encoding["field"] == "Horsepower"
        assert x_encoding["title"] == "Horse"
        assert config["encoding"]["y"]["field"] == "Miles_per_Gallon"
        # Interactive by default: an interval selection bound to the scales.
        assert config["selection"] == {
            "selector001": {
                "bind": "scales",
                "encodings": ["x", "y"],
                "type": "interval",
            }
        }
        assert config["title"] == "Car Data"
        assert "height" not in config
        assert "width" not in config

    def test_no_interactive(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
        )
        output = plot.postprocess(cars)
        assert sorted(output) == ["chart", "plot", "type"]
        assert "selection" not in self._plot_config(output)

    def test_height_width(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", height=100, width=200
        )
        output = plot.postprocess(cars)
        assert sorted(output) == ["chart", "plot", "type"]
        config = self._plot_config(output)
        assert (config["height"], config["width"]) == (100, 200)

    def test_xlim_ylim(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500]
        )
        config = self._plot_config(plot.postprocess(cars))
        for axis, domain in (("x", [200, 400]), ("y", [300, 500])):
            assert config["encoding"][axis]["scale"] == {"domain": domain}

    def test_color_encoding(self):
        plot = gr.ScatterPlot(
            x="Horsepower",
            y="Miles_per_Gallon",
            tooltip="Name",
            title="Car Data",
            color="Origin",
        )
        color = self._plot_config(plot.postprocess(cars))["encoding"]["color"]
        assert color["field"] == "Origin"
        # A string column yields a nominal scale over its distinct values.
        assert color["scale"] == {
            "domain": ["USA", "Europe", "Japan"],
            "range": [0, 1, 2],
        }
        assert color["type"] == "nominal"

    def test_two_encodings(self):
        plot = gr.ScatterPlot(
            show_label=False,
            title="Two encodings",
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Acceleration",
            shape="Origin",
        )
        encoding = self._plot_config(plot.postprocess(cars))["encoding"]
        color, shape = encoding["color"], encoding["shape"]
        assert color["field"] == "Acceleration"
        # A numeric column yields a quantitative min..max scale.
        assert color["scale"] == {
            "domain": [cars.Acceleration.min(), cars.Acceleration.max()],
            "range": [0, 1],
        }
        assert color["type"] == "quantitative"
        assert shape["field"] == "Origin"
        assert shape["type"] == "nominal"

    def test_legend_position(self):
        common = dict(
            title="Two encodings",
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Acceleration",
            color_legend_title="Foo",
            shape="Origin",
            shape_legend_title="Bar",
            size="Acceleration",
            size_legend_title="Accel",
        )

        # Position "none" hides the legend of every channel.
        plot = gr.ScatterPlot(
            show_label=False,
            color_legend_position="none",
            shape_legend_position="none",
            size_legend_position="none",
            **common,
        )
        encoding = self._plot_config(plot.postprocess(cars))["encoding"]
        for channel in ("color", "shape", "size"):
            assert encoding[channel]["legend"] is None

        # Other positions are forwarded as the legend "orient".
        output = gr.ScatterPlot.update(
            value=cars,
            color_legend_position="top",
            shape_legend_position="bottom",
            size_legend_position="left",
            **common,
        )
        encoding = self._plot_config(output["value"])["encoding"]
        for channel, orient in (("color", "top"), ("shape", "bottom"), ("size", "left")):
            assert encoding[channel]["legend"]["orient"] == orient

    def test_update(self):
        output = gr.ScatterPlot.update(value=cars, x="Horsepower", y="Miles_per_Gallon")
        # Postprocessing the value returned by update() leaves it unchanged.
        assert gr.ScatterPlot().postprocess(output["value"]) == output["value"]

    def test_update_visibility(self):
        output = gr.ScatterPlot.update(visible=False)
        assert not output["visible"]
        assert output["value"] is gr.components._Keywords.NO_VALUE

    def test_update_errors(self):
        cases = [
            (
                dict(x="foo", y="bar"),
                "In order to update plot properties the value parameter",
            ),
            (
                dict(value=cars, x="foo"),
                "In order to update plot properties, the x and y axis data",
            ),
        ]
        for kwargs, message in cases:
            with pytest.raises(ValueError, match=message):
                gr.ScatterPlot.update(**kwargs)

    def test_scatterplot_accepts_fn_as_value(self):
        plot = gr.ScatterPlot(
            value=lambda: cars.sample(frac=0.1, replace=False),
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Origin",
        )
        assert isinstance(plot.value, dict)
        assert isinstance(plot.value["plot"], str)
2022-12-21 00:13:51 +08:00
class TestLinePlot :
2023-02-18 05:47:06 +08:00
@patch.dict ( " sys.modules " , { " bokeh " : MagicMock ( __version__ = " 3.0.3 " ) } )
2022-12-21 00:13:51 +08:00
def test_get_config ( self ) :
assert gr . LinePlot ( ) . get_config ( ) == {
" caption " : None ,
" elem_id " : None ,
2023-03-16 05:01:53 +08:00
" elem_classes " : None ,
2022-12-21 00:13:51 +08:00
" interactive " : None ,
" label " : None ,
" name " : " plot " ,
" root_url " : None ,
" show_label " : True ,
" style " : { } ,
" value " : None ,
" visible " : True ,
2023-02-18 05:47:06 +08:00
" bokeh_version " : " 3.0.3 " ,
2022-12-21 00:13:51 +08:00
}
def test_no_color ( self ) :
plot = gr . LinePlot (
x = " date " ,
y = " price " ,
tooltip = [ " symbol " , " price " ] ,
title = " Stock Performance " ,
x_title = " Trading Day " ,
)
output = plot . postprocess ( stocks )
assert sorted ( list ( output . keys ( ) ) ) == [ " chart " , " plot " , " type " ]
config = json . loads ( output [ " plot " ] )
for layer in config [ " layer " ] :
assert layer [ " mark " ] [ " type " ] in [ " line " , " point " ]
assert layer [ " encoding " ] [ " x " ] [ " field " ] == " date "
assert layer [ " encoding " ] [ " x " ] [ " title " ] == " Trading Day "
assert layer [ " encoding " ] [ " y " ] [ " field " ] == " price "
assert config [ " title " ] == " Stock Performance "
assert " height " not in config
assert " width " not in config
def test_height_width ( self ) :
plot = gr . LinePlot ( x = " date " , y = " price " , height = 100 , width = 200 )
output = plot . postprocess ( stocks )
assert sorted ( list ( output . keys ( ) ) ) == [ " chart " , " plot " , " type " ]
config = json . loads ( output [ " plot " ] )
assert config [ " height " ] == 100
assert config [ " width " ] == 200
output = gr . LinePlot . update ( stocks , x = " date " , y = " price " , height = 100 , width = 200 )
config = json . loads ( output [ " value " ] [ " plot " ] )
assert config [ " height " ] == 100
assert config [ " width " ] == 200
def test_xlim_ylim ( self ) :
plot = gr . LinePlot ( x = " date " , y = " price " , x_lim = [ 200 , 400 ] , y_lim = [ 300 , 500 ] )
output = plot . postprocess ( stocks )
config = json . loads ( output [ " plot " ] )
for layer in config [ " layer " ] :
assert layer [ " encoding " ] [ " x " ] [ " scale " ] == { " domain " : [ 200 , 400 ] }
assert layer [ " encoding " ] [ " y " ] [ " scale " ] == { " domain " : [ 300 , 500 ] }
def test_color_encoding ( self ) :
plot = gr . LinePlot (
x = " date " , y = " price " , tooltip = " symbol " , color = " symbol " , overlay_point = True
)
output = plot . postprocess ( stocks )
config = json . loads ( output [ " plot " ] )
for layer in config [ " layer " ] :
assert layer [ " encoding " ] [ " color " ] [ " field " ] == " symbol "
assert layer [ " encoding " ] [ " color " ] [ " scale " ] == {
" domain " : [ " MSFT " , " AMZN " , " IBM " , " GOOG " , " AAPL " ] ,
" range " : [ 0 , 1 , 2 , 3 , 4 ] ,
}
assert layer [ " encoding " ] [ " color " ] [ " type " ] == " nominal "
if layer [ " mark " ] [ " type " ] == " point " :
assert layer [ " encoding " ] [ " opacity " ] == { }
def test_two_encodings ( self ) :
output = gr . LinePlot . update (
value = stocks ,
title = " Two encodings " ,
x = " date " ,
y = " price " ,
color = " symbol " ,
stroke_dash = " symbol " ,
color_legend_title = " Color " ,
stroke_dash_legend_title = " Stroke Dash " ,
)
config = json . loads ( output [ " value " ] [ " plot " ] )
for layer in config [ " layer " ] :
if layer [ " mark " ] [ " type " ] == " point " :
assert layer [ " encoding " ] [ " opacity " ] == { " value " : 0 }
if layer [ " mark " ] [ " type " ] == " line " :
assert layer [ " encoding " ] [ " strokeDash " ] [ " field " ] == " symbol "
assert (
layer [ " encoding " ] [ " strokeDash " ] [ " legend " ] [ " title " ] == " Stroke Dash "
)
def test_legend_position(self):
    """Legend positions are applied per encoding.

    A position of ``"none"`` yields a null legend after JSON decoding; any
    other position is written to the legend's ``orient`` field.
    """
    plot = gr.LinePlot(
        value=stocks,
        title="Two encodings",
        x="date",
        y="price",
        color="symbol",
        stroke_dash="symbol",
        color_legend_position="none",
        stroke_dash_legend_position="none",
    )
    output = plot.postprocess(stocks)
    config = json.loads(output["plot"])
    # All checks are gated on the mark type; require both layers so none of
    # them can be skipped silently.
    assert {"point", "line"} <= {layer["mark"]["type"] for layer in config["layer"]}
    for layer in config["layer"]:
        if layer["mark"]["type"] == "point":
            assert layer["encoding"]["color"]["legend"] is None
        if layer["mark"]["type"] == "line":
            assert layer["encoding"]["strokeDash"]["legend"] is None
            assert layer["encoding"]["color"]["legend"] is None

    output = gr.LinePlot.update(
        value=stocks,
        title="Two encodings",
        x="date",
        y="price",
        color="symbol",
        stroke_dash="symbol",
        color_legend_position="top-right",
        stroke_dash_legend_position="top-left",
    )
    config = json.loads(output["value"]["plot"])
    assert {"point", "line"} <= {layer["mark"]["type"] for layer in config["layer"]}
    for layer in config["layer"]:
        if layer["mark"]["type"] == "point":
            assert layer["encoding"]["color"]["legend"]["orient"] == "top-right"
        if layer["mark"]["type"] == "line":
            assert layer["encoding"]["strokeDash"]["legend"]["orient"] == "top-left"
            assert layer["encoding"]["color"]["legend"]["orient"] == "top-right"
def test_update_visibility(self):
    """A visibility-only update hides the plot and carries the NO_VALUE sentinel."""
    result = gr.LinePlot.update(visible=False)
    no_value = gr.components._Keywords.NO_VALUE
    assert result["value"] is no_value
    assert not result["visible"]
def test_update_errors(self):
    """update() rejects plot properties without a value, and a value without both axes."""
    bad_calls = [
        (
            {"x": "foo", "y": "bar"},
            "In order to update plot properties the value parameter",
        ),
        (
            {"value": stocks, "x": "foo"},
            "In order to update plot properties, the x and y axis data",
        ),
    ]
    for kwargs, expected_message in bad_calls:
        with pytest.raises(ValueError, match=expected_message):
            gr.LinePlot.update(**kwargs)
def test_lineplot_accepts_fn_as_value(self):
    """After construction with a callable value, ``.value`` is a rendered plot dict."""

    def sample_stocks():
        return stocks.sample(frac=0.1, replace=False)

    line_plot = gr.LinePlot(value=sample_stocks, x="date", y="price", color="symbol")
    rendered = line_plot.value
    assert isinstance(rendered, dict)
    assert isinstance(rendered["plot"], str)
2023-02-10 05:42:25 +08:00
class TestBarPlot:
    @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
    def test_get_config(self):
        """Default BarPlot config; bokeh is stubbed so its reported version is fixed."""
        expected = {
            "bokeh_version": "3.0.3",
            "caption": None,
            "elem_classes": None,
            "elem_id": None,
            "interactive": None,
            "label": None,
            "name": "plot",
            "root_url": None,
            "show_label": True,
            "style": {},
            "value": None,
            "visible": True,
        }
        assert gr.BarPlot().get_config() == expected

    def test_no_color(self):
        """Without an explicit y title the y column name is used; no size is set."""
        bar_plot = gr.BarPlot(
            x="a",
            y="b",
            tooltip=["a", "b"],
            title="Made Up Bar Plot",
            x_title="Variable A",
        )
        result = bar_plot.postprocess(simple)
        assert sorted(result) == ["chart", "plot", "type"]
        assert result["chart"] == "bar"
        spec = json.loads(result["plot"])
        x_axis, y_axis = spec["encoding"]["x"], spec["encoding"]["y"]
        assert (x_axis["field"], x_axis["title"]) == ("a", "Variable A")
        assert (y_axis["field"], y_axis["title"]) == ("b", "b")
        assert spec["title"] == "Made Up Bar Plot"
        assert "height" not in spec
        assert "width" not in spec

    def test_height_width(self):
        """height/width reach the spec via both the constructor and update()."""
        result = gr.BarPlot(x="a", y="b", height=100, width=200).postprocess(simple)
        assert sorted(result) == ["chart", "plot", "type"]
        updated = gr.BarPlot.update(simple, x="a", y="b", height=100, width=200)
        for serialized in (result["plot"], updated["value"]["plot"]):
            spec = json.loads(serialized)
            assert (spec["height"], spec["width"]) == (100, 200)

    def test_ylim(self):
        """y_lim becomes the y scale domain."""
        bar_plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100])
        encoding = json.loads(bar_plot.postprocess(simple)["plot"])["encoding"]
        assert encoding["y"]["scale"] == {"domain": [15, 100]}

    def test_horizontal(self):
        """vertical=False swaps the axes: x carries the y column, title and limits."""
        updated = gr.BarPlot.update(
            simple,
            x="a",
            y="b",
            x_title="Variable A",
            y_title="Variable B",
            title="Simple Bar Plot with made up data",
            tooltip=["a", "b"],
            vertical=False,
            y_lim=[20, 100],
        )
        payload = updated["value"]
        assert payload["chart"] == "bar"
        encoding = json.loads(payload["plot"])["encoding"]
        x_axis, y_axis = encoding["x"], encoding["y"]
        assert x_axis["field"] == "b"
        assert x_axis["scale"] == {"domain": [20, 100]}
        assert x_axis["title"] == "Variable B"
        assert y_axis["field"] == "a"
        assert y_axis["title"] == "Variable A"

    def test_stack_via_color(self):
        """A color column gets a legend and a scale over its unique values."""
        updated = gr.BarPlot.update(
            barley,
            x="variety",
            y="yield",
            color="site",
            title="Barley Yield Data",
            color_legend_title="Site",
            color_legend_position="bottom",
        )
        color = json.loads(updated["value"]["plot"])["encoding"]["color"]
        sites = [
            "University Farm",
            "Waseca",
            "Morris",
            "Crookston",
            "Grand Rapids",
            "Duluth",
        ]
        assert color["field"] == "site"
        assert color["legend"] == {"title": "Site", "orient": "bottom"}
        assert color["scale"] == {"domain": sites, "range": list(range(len(sites)))}

    def test_group(self):
        """Vertical grouping facets by column."""
        updated = gr.BarPlot.update(
            barley,
            x="year",
            y="yield",
            color="year",
            group="site",
            title="Barley Yield by Year and Site",
            group_title="",
            tooltip=["yield", "site", "year"],
        )
        encoding = json.loads(updated["value"]["plot"])["encoding"]
        assert encoding["column"] == {"field": "site", "title": ""}

    def test_group_horizontal(self):
        """Horizontal grouping facets by row."""
        updated = gr.BarPlot.update(
            barley,
            x="year",
            y="yield",
            color="year",
            group="site",
            title="Barley Yield by Year and Site",
            group_title="Site Title",
            tooltip=["yield", "site", "year"],
            vertical=False,
        )
        encoding = json.loads(updated["value"]["plot"])["encoding"]
        assert encoding["row"] == {"field": "site", "title": "Site Title"}

    def test_barplot_accepts_fn_as_value(self):
        """After construction with a callable value, ``.value`` is a rendered plot dict."""

        def sample_barley():
            return barley.sample(frac=0.1, replace=False)

        rendered = gr.BarPlot(value=sample_barley, x="year", y="yield").value
        assert isinstance(rendered, dict)
        assert isinstance(rendered["plot"], str)
2023-03-11 01:52:17 +08:00
class TestCode:
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, get_config
        """
        code = gr.Code()
        snippet = "def fn(a):\n    return a"

        # preprocess hands source text back unchanged.
        for source in ("# hello friends", snippet):
            assert code.preprocess(source) == source

        # postprocess trims the blank lines surrounding a plain string.
        assert code.postprocess("\n" + snippet + "\n") == snippet

        # A plain path string is returned as-is; a 1-tuple holding a path
        # is replaced by that file's contents.
        test_file_dir = Path(Path(__file__).parent, "test_files")
        path = str(Path(test_file_dir, "test_label_json.json"))
        with open(path) as f:
            file_contents = f.read()
        assert code.postprocess(path) == path
        assert code.postprocess((path,)) == file_contents

        assert code.serialize(snippet) == snippet
        assert code.deserialize(snippet) == snippet

        expected_config = {
            "elem_classes": None,
            "elem_id": None,
            "interactive": None,
            "label": None,
            "language": None,
            "name": "code",
            "root_url": None,
            "show_label": True,
            "style": {},
            "value": None,
            "visible": True,
        }
        assert code.get_config() == expected_config
2023-03-22 00:37:24 +08:00
class TestTempFileManagement :
def test_hash_file(self):
    """The copy of cheetah1 hashes like the original; cheetah2 hashes differently."""
    manager = gr.File()
    original, duplicate, different = (
        manager.hash_file(path)
        for path in (
            "gradio/test_data/cheetah1.jpg",
            "gradio/test_data/cheetah1-copy.jpg",
            "gradio/test_data/cheetah2.jpg",
        )
    )
    assert original == duplicate
    assert original != different
@patch("shutil.copy2")
def test_make_temp_copy_if_needed(self, mock_copy):
    """Each distinct source is tracked once, and its temp copy keeps the basename.

    ``shutil.copy2`` is mocked, so no bytes are actually copied.
    """
    manager = gr.File()
    source = "gradio/test_data/cheetah1.jpg"

    leftover = manager.make_temp_copy_if_needed(source)
    # Delete if already exists from before this test.
    try:
        os.remove(leftover)
    except OSError:
        pass

    temp_path = manager.make_temp_copy_if_needed(source)
    assert mock_copy.called
    assert len(manager.temp_files) == 1
    assert Path(temp_path).name == "cheetah1.jpg"

    # Requesting the same source again must not register a second file.
    manager.make_temp_copy_if_needed(source)
    assert len(manager.temp_files) == 1

    copy_path = manager.make_temp_copy_if_needed("gradio/test_data/cheetah1-copy.jpg")
    assert len(manager.temp_files) == 2
    assert Path(copy_path).name == "cheetah1-copy.jpg"
def test_base64_to_temp_file_if_needed(self):
    """One temp file is registered per distinct base64 payload.

    The created files are removed in a ``finally`` block so that a failing
    assertion does not leave them behind in the temp directory for later runs.
    """
    temp_file_manager = gr.File()
    base64_file_1 = media_data.BASE64_IMAGE
    base64_file_2 = media_data.BASE64_AUDIO["data"]
    try:
        f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
        try:  # Delete if already exists from before this test
            os.remove(f)
        except OSError:
            pass
        temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
        assert len(temp_file_manager.temp_files) == 1
        # The same payload again must not register a second file.
        temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
        assert len(temp_file_manager.temp_files) == 1
        temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
        assert len(temp_file_manager.temp_files) == 2
    finally:
        # Cleanup: a file that is already gone must not raise here and mask
        # the assertion error that sent us into this block.
        for file in temp_file_manager.temp_files:
            try:
                os.remove(file)
            except FileNotFoundError:
                pass
@pytest.mark.flaky
@patch ( " shutil.copyfileobj " )
def test_download_temp_copy_if_needed ( self , mock_copy ) :
temp_file_manager = gr . File ( )
url1 = " https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png "
url2 = " https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg "
f = temp_file_manager . download_temp_copy_if_needed ( url1 )
try : # Delete if already exists from before this test
os . remove ( f )
except OSError :
pass
f = temp_file_manager . download_temp_copy_if_needed ( url1 )
assert mock_copy . called
assert len ( temp_file_manager . temp_files ) == 1
f = temp_file_manager . download_temp_copy_if_needed ( url1 )
assert len ( temp_file_manager . temp_files ) == 1
f = temp_file_manager . download_temp_copy_if_needed ( url2 )
assert len ( temp_file_manager . temp_files ) == 2