2022-03-30 04:23:30 +08:00
import json
import os
import tempfile
import unittest
2022-04-05 17:54:17 +08:00
from copy import deepcopy
2022-03-30 04:23:30 +08:00
from difflib import SequenceMatcher
import matplotlib . pyplot as plt
import numpy as np
import pandas as pd
import PIL
2022-03-30 20:34:08 +08:00
2022-03-30 04:23:30 +08:00
import gradio as gr
2022-04-06 20:55:51 +08:00
from gradio . test_data import media_data
2022-03-30 04:23:30 +08:00
# Keep gradio's analytics switched off for the whole test run.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

"""
Tests are divided into two categories:
1. test_component_functions are unit tests that check the essential functions of a component; the functions each test checks are listed in its docstring.
2. test_in_interface_... are functional tests that check a component's functionality inside an Interface. Please do not use Interface.launch() in this file, as it slows down the tests.
"""
class TestComponent(unittest.TestCase):
    def test_component_functions(self):
        """
        gr.component: a string shortcut resolves to the matching template class.
        """
        # assertIsInstance (instead of a bare `assert`) keeps the check active
        # under `python -O` and reports the actual type on failure.
        self.assertIsInstance(gr.component("text"), gr.templates.Text)
class TestTextbox(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, save_flagged, restore_flagged, tokenize, generate_sample, get_template_context
        """
        text_input = gr.Textbox()
        self.assertEqual(text_input.preprocess("Hello World!"), "Hello World!")
        self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!")
        # postprocess passes None through and turns numbers into their str() form.
        self.assertEqual(text_input.postprocess(None), None)
        self.assertEqual(text_input.postprocess("Ali"), "Ali")
        self.assertEqual(text_input.postprocess(2), "2")
        self.assertEqual(text_input.postprocess(2.14), "2.14")
        self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!")
        # Flagging stores the raw string; restoring hands the same string back.
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = text_input.save_flagged(
                tmpdirname, "text_input", "Hello World!", None
            )
            self.assertEqual(to_save, "Hello World!")
            restored = text_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, "Hello World!")
        # type="number" is expected to emit a DeprecationWarning.
        with self.assertWarns(DeprecationWarning):
            _ = gr.Textbox(type="number")
        # tokenize returns (tokens, one masked variant per token, None); with the
        # default settings the masked token is simply dropped from the sentence.
        self.assertEqual(
            text_input.tokenize("Hello World! Gradio speaking."),
            (
                ["Hello", "World!", "Gradio", "speaking."],
                [
                    "World! Gradio speaking.",
                    "Hello Gradio speaking.",
                    "Hello World! speaking.",
                    "Hello World! Gradio",
                ],
                None,
            ),
        )
        # With interpretation_replacement set, the masked token is replaced by
        # that string instead. NOTE: this mutates text_input, so it must stay
        # after the default-masking check above.
        text_input.interpretation_replacement = "unknown"
        self.assertEqual(
            text_input.tokenize("Hello World! Gradio speaking."),
            (
                ["Hello", "World!", "Gradio", "speaking."],
                [
                    "unknown World! Gradio speaking.",
                    "Hello unknown Gradio speaking.",
                    "Hello World! unknown speaking.",
                    "Hello World! Gradio unknown",
                ],
                None,
            ),
        )
        self.assertEqual(
            text_input.get_template_context(),
            {
                "lines": 1,
                "placeholder": None,
                "default_value": "",
                "name": "textbox",
                "label": None,
                "css": {},
                "interactive": None,
            },
        )
        self.assertIsInstance(text_input.generate_sample(), str)

    def test_in_interface_as_input(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
        self.assertEqual(iface.process(["Hello"]), ["olleH"])
        iface = gr.Interface(
            lambda sentence: max([len(word) for word in sentence.split()]),
            gr.Textbox(),
            "number",
            interpretation="default",
        )
        scores = iface.interpret(
            ["Return the length of the longest word in this sentence"]
        )[0]["interpretation"]
        # Each word gets a score (only "sentence", the longest word, moves the
        # output) and is followed by a (" ", 0) separator entry.
        self.assertEqual(
            scores,
            [
                ("Return", 0.0),
                (" ", 0),
                ("the", 0.0),
                (" ", 0),
                ("length", 0.0),
                (" ", 0),
                ("of", 0.0),
                (" ", 0),
                ("the", 0.0),
                (" ", 0),
                ("longest", 0.0),
                (" ", 0),
                ("word", 0.0),
                (" ", 0),
                ("in", 0.0),
                (" ", 0),
                ("this", 0.0),
                (" ", 0),
                ("sentence", 1.0),
                (" ", 0),
            ],
        )

    def test_in_interface_as_output(self):
        """
        Interface, process
        """
        iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
        self.assertEqual(iface.process(["Hello"]), ["o"])
        # A numeric result is rendered as its string form by the Textbox output.
        iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
        self.assertEqual(iface.process([10]), ["5.0"])
class TestNumber(unittest.TestCase):
    # Expected "default" interpretation scores of f(x) = x**2 evaluated at x = 2.
    # Both interface tests below build the same number -> number interface, so
    # they share this literal instead of each keeping its own copy.
    _SQUARE_AT_2_SCORES = [
        (1.94, -0.23640000000000017),
        (1.96, -0.15840000000000032),
        (1.98, -0.07960000000000012),
        [2, None],
        (2.02, 0.08040000000000003),
        (2.04, 0.16159999999999997),
        (2.06, 0.24359999999999982),
    ]

    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
        """
        numeric_input = gr.Number()
        self.assertEqual(numeric_input.preprocess(3), 3.0)
        self.assertEqual(numeric_input.preprocess(None), None)
        self.assertEqual(numeric_input.preprocess_example(3), 3)
        self.assertEqual(numeric_input.postprocess(3), 3.0)
        self.assertEqual(numeric_input.postprocess(2.14), 2.14)
        self.assertEqual(numeric_input.postprocess(None), None)
        self.assertEqual(numeric_input.serialize(3, True), 3)
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None)
            self.assertEqual(to_save, 3)
            restored = numeric_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, 3)
        self.assertIsInstance(numeric_input.generate_sample(), float)
        # Absolute deltas: `steps` neighbours on each side of 1, `delta` apart.
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
        self.assertEqual(
            numeric_input.get_interpretation_neighbors(1),
            ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}),
        )
        # Percent deltas: neighbours are `delta` percent of the input apart.
        numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
        self.assertEqual(
            numeric_input.get_interpretation_neighbors(1),
            ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}),
        )
        self.assertEqual(
            numeric_input.get_template_context(),
            {
                "default_value": None,
                "name": "number",
                "label": None,
                "css": {},
                "interactive": None,
            },
        )

    def test_in_interface_as_input(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x**2, "number", "textbox")
        self.assertEqual(iface.process([2]), ["4.0"])
        iface = gr.Interface(
            lambda x: x**2, "number", "number", interpretation="default"
        )
        scores = iface.interpret([2])[0]["interpretation"]
        self.assertEqual(scores, self._SQUARE_AT_2_SCORES)

    def test_in_interface_as_output(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
        self.assertEqual(iface.process([2]), [4.0])
        iface = gr.Interface(
            lambda x: x**2, "number", "number", interpretation="default"
        )
        scores = iface.interpret([2])[0]["interpretation"]
        self.assertEqual(scores, self._SQUARE_AT_2_SCORES)
class TestSlider(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context
        """
        slider_input = gr.Slider()
        self.assertEqual(slider_input.preprocess(3.0), 3.0)
        self.assertEqual(slider_input.preprocess_example(3), 3)
        self.assertEqual(slider_input.postprocess(3), 3)
        self.assertEqual(slider_input.postprocess(None), None)
        self.assertEqual(slider_input.serialize(3, True), 3)
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = slider_input.save_flagged(tmpdirname, "slider_input", 3, None)
            self.assertEqual(to_save, 3)
            restored = slider_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, 3)
        self.assertIsInstance(slider_input.generate_sample(), int)
        slider_input = gr.Slider(
            default_value=15, minimum=10, maximum=20, step=1, label="Slide Your Input"
        )
        self.assertEqual(
            slider_input.get_template_context(),
            {
                "minimum": 10,
                "maximum": 20,
                "step": 1,
                "default_value": 15,
                "name": "slider",
                "label": "Slide Your Input",
                "css": {},
                "interactive": None,
            },
        )

    def test_in_interface(self):
        """
        Interface, process, interpret
        """
        iface = gr.Interface(lambda x: x**2, "slider", "textbox")
        self.assertEqual(iface.process([2]), ["4"])
        iface = gr.Interface(
            lambda x: x**2, "slider", "number", interpretation="default"
        )
        scores = iface.interpret([2])[0]["interpretation"]
        # The scores appear to be f(v) - f(2) for evenly spaced slider values v
        # (first -4.0 == 0**2 - 4, last 9996.0 == 100**2 - 4).
        self.assertEqual(
            scores,
            [
                -4.0,
                200.08163265306123,
                812.3265306122449,
                1832.7346938775513,
                3261.3061224489797,
                5098.040816326531,
                7342.938775510205,
                9996.0,
            ],
        )
class TestCheckbox(unittest.TestCase):
    def test_component_functions(self):
        """
        Checks preprocess, preprocess_example, postprocess, serialize,
        save_flagged, restore_flagged, generate_sample and get_template_context.
        """
        checkbox = gr.Checkbox()
        self.assertEqual(checkbox.preprocess(True), True)
        self.assertEqual(checkbox.preprocess_example(True), True)
        self.assertEqual(checkbox.postprocess(True), True)
        self.assertEqual(checkbox.serialize(True, True), True)

        # The flagged value is the boolean itself, round-tripped unchanged.
        with tempfile.TemporaryDirectory() as flag_dir:
            saved = checkbox.save_flagged(flag_dir, "bool_input", True, None)
            self.assertEqual(saved, True)
            self.assertEqual(checkbox.restore_flagged(flag_dir, saved, None), True)
        self.assertIsInstance(checkbox.generate_sample(), bool)

        labelled = gr.Checkbox(default_value=True, label="Check Your Input")
        context = labelled.get_template_context()
        self.assertEqual(
            context,
            {
                "default_value": True,
                "name": "checkbox",
                "label": "Check Your Input",
                "css": {},
                "interactive": None,
            },
        )

    def test_in_interface(self):
        """
        Checks Interface.process and default interpretation for a checkbox input.
        """
        plain = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
        self.assertEqual(plain.process([True]), [1])

        interpreted = gr.Interface(
            lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
        )
        # Judging by the expected values, the pair holds None in the slot of the
        # current state and the output change from flipping it in the other slot.
        cases = ((False, (None, 1.0)), (True, (-1.0, None)))
        for flag, expected in cases:
            outcome = interpreted.interpret([flag])[0]
            self.assertEqual(outcome["interpretation"], expected)
class TestCheckboxGroup(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, preprocess_example, serialize, save_flagged, restore_flagged, generate_sample, get_template_context
        """
        checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
        self.assertEqual(checkboxes_input.preprocess(["a", "c"]), ["a", "c"])
        self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"])
        self.assertEqual(checkboxes_input.serialize(["a", "c"], True), ["a", "c"])
        # The selection is flagged as a JSON-style string and decoded on restore.
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = checkboxes_input.save_flagged(
                tmpdirname, "checkboxes_input", ["a", "c"], None
            )
            self.assertEqual(to_save, '["a", "c"]')
            restored = checkboxes_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, ["a", "c"])
        self.assertIsInstance(checkboxes_input.generate_sample(), list)
        checkboxes_input = gr.CheckboxGroup(
            default_selected=["a", "c"],
            choices=["a", "b", "c"],
            label="Check Your Inputs",
        )
        self.assertEqual(
            checkboxes_input.get_template_context(),
            {
                "choices": ["a", "b", "c"],
                "default_value": ["a", "c"],
                "name": "checkboxgroup",
                "label": "Check Your Inputs",
                "css": {},
                "interactive": None,
            },
        )
        with self.assertRaises(ValueError):
            wrong_type = gr.CheckboxGroup(["a"], type="unknown")
            wrong_type.preprocess(0)

    def test_in_interface(self):
        """
        Interface, process, preprocess with type="index"
        """
        checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
        iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
        self.assertEqual(iface.process([["a", "c"]]), ["a|c"])
        self.assertEqual(iface.process([[]]), [""])
        # The component was previously constructed here and thrown away without
        # any check; assert that type="index" maps selections to their positions.
        index_input = gr.CheckboxGroup(["a", "b", "c"], type="index")
        self.assertEqual(index_input.preprocess(["a", "c"]), [0, 2])
class TestRadio(unittest.TestCase):
    def test_component_functions(self):
        """
        Checks preprocess, preprocess_example, serialize, save_flagged,
        restore_flagged, generate_sample and get_template_context, plus the
        ValueError for an unknown `type`.
        """
        radio = gr.Radio(["a", "b", "c"])
        self.assertEqual(radio.preprocess("c"), "c")
        self.assertEqual(radio.preprocess_example("a"), "a")
        self.assertEqual(radio.serialize("a", True), "a")

        with tempfile.TemporaryDirectory() as flag_dir:
            saved = radio.save_flagged(flag_dir, "radio_input", "a", None)
            self.assertEqual(saved, "a")
            self.assertEqual(radio.restore_flagged(flag_dir, saved, None), "a")
        self.assertIsInstance(radio.generate_sample(), str)

        # NOTE(review): this passes `default=` while other components here use
        # `default_value=`; "a" is also the first choice, so confirm the expected
        # "default_value" really comes from this keyword.
        labelled = gr.Radio(
            choices=["a", "b", "c"], default="a", label="Pick Your One Input"
        )
        context = labelled.get_template_context()
        self.assertEqual(
            context,
            {
                "choices": ["a", "b", "c"],
                "default_value": "a",
                "name": "radio",
                "label": "Pick Your One Input",
                "css": {},
                "interactive": None,
            },
        )

        with self.assertRaises(ValueError):
            bad_radio = gr.Radio(["a", "b"], type="unknown")
            bad_radio.preprocess(0)

    def test_in_interface(self):
        """
        Checks Interface.process for value- and index-typed radios, and default
        interpretation for the index-typed one.
        """
        by_value = gr.Interface(lambda x: 2 * x, gr.Radio(["a", "b", "c"]), "textbox")
        self.assertEqual(by_value.process(["c"]), ["cc"])

        index_radio = gr.Radio(["a", "b", "c"], type="index")
        by_index = gr.Interface(
            lambda x: 2 * x, index_radio, "number", interpretation="default"
        )
        # With type="index", "c" reaches the function as 2, so the output is 4.
        self.assertEqual(by_index.process(["c"]), [4])
        outcome = by_index.interpret(["b"])[0]
        self.assertEqual(outcome["interpretation"], [-2.0, None, 2.0])
class TestImage(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context, _segment_by_slic
        Type: pil, file, filepath, numpy
        """
        img = deepcopy(media_data.BASE64_IMAGE)
        image_input = gr.Image()
        self.assertEqual(image_input.preprocess(img).shape, (68, 61, 3))
        image_input = gr.Image(shape=(25, 25), image_mode="L")
        self.assertEqual(image_input.preprocess(img).shape, (25, 25))
        image_input = gr.Image(shape=(30, 10), type="pil")
        self.assertEqual(image_input.preprocess(img).size, (30, 10))
        self.assertEqual(image_input.preprocess_example("test/test_files/bus.png"), img)
        self.assertEqual(image_input.serialize("test/test_files/bus.png", True), img)
        # Successive flags of the same component get increasing file indices.
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = image_input.save_flagged(tmpdirname, "image_input", img, None)
            self.assertEqual("image_input/0.png", to_save)
            to_save = image_input.save_flagged(tmpdirname, "image_input", img, None)
            self.assertEqual("image_input/1.png", to_save)
            restored = image_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, os.path.join(tmpdirname, "image_input/1.png"))
        self.assertIsInstance(image_input.generate_sample(), str)
        image_input = gr.Image(
            source="upload", tool="editor", type="pil", label="Upload Your Image"
        )
        self.assertEqual(
            image_input.get_template_context(),
            {
                "image_mode": "RGB",
                "shape": None,
                "source": "upload",
                "tool": "editor",
                "name": "image",
                "label": "Upload Your Image",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        self.assertIsNone(image_input.preprocess(None))
        image_input = gr.Image(invert_colors=True)
        self.assertIsNotNone(image_input.preprocess(img))
        with self.assertWarns(DeprecationWarning):
            file_image = gr.Image(type="file")
            file_image.preprocess(deepcopy(media_data.BASE64_IMAGE))
        file_image = gr.Image(type="filepath")
        self.assertIsInstance(file_image.preprocess(img), str)
        with self.assertRaises(ValueError):
            wrong_type = gr.Image(type="unknown")
            wrong_type.preprocess(img)
        with self.assertRaises(ValueError):
            wrong_type = gr.Image(type="unknown")
            wrong_type.serialize("test/test_files/bus.png", False)
        # Close the PIL file handle once both serialize calls are done with it.
        with PIL.Image.open("test/test_files/bus.png") as img_pil:
            image_input = gr.Image(type="numpy")
            self.assertIsInstance(image_input.serialize(img_pil, False), str)
            image_input = gr.Image(type="pil")
            self.assertIsInstance(image_input.serialize(img_pil, False), str)
        image_input = gr.Image(type="file")
        # PNG data is binary, so open the file in binary mode.
        with open("test/test_files/bus.png", "rb") as f:
            self.assertEqual(image_input.serialize(f, False), img)
        image_input.shape = (30, 10)
        self.assertIsNotNone(image_input._segment_by_slic(img))

        # Output functionalities
        y_img = gr.processing_utils.decode_base64_to_image(
            deepcopy(media_data.BASE64_IMAGE)
        )
        image_output = gr.Image()
        # The expected prefix here was empty, which made these checks vacuous;
        # assert the actual PNG data-URL prefix instead.
        self.assertTrue(
            image_output.postprocess(y_img).startswith("data:image/png;base64,")
        )
        self.assertTrue(
            image_output.postprocess(np.array(y_img)).startswith(
                "data:image/png;base64,"
            )
        )
        with self.assertWarns(DeprecationWarning):
            plot_output = gr.Image(plot=True)
        xpoints = np.array([0, 6])
        ypoints = np.array([0, 250])
        fig = plt.figure()
        # Release the figure even if an assertion below fails; otherwise it stays
        # registered with pyplot for the rest of the test run.
        self.addCleanup(plt.close, fig)
        plt.plot(xpoints, ypoints)
        self.assertTrue(
            plot_output.postprocess(fig).startswith("data:image/png;base64,")
        )
        with self.assertRaises(ValueError):
            image_output.postprocess([1, 2, 3])
        image_output = gr.Image(type="numpy")
        self.assertTrue(
            image_output.postprocess(y_img).startswith("data:image/png;base64,")
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = image_output.save_flagged(
                tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
            )
            self.assertEqual("image_output/0.png", to_save)
            to_save = image_output.save_flagged(
                tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
            )
            self.assertEqual("image_output/1.png", to_save)

    def test_in_interface_as_input(self):
        """
        Interface, process, interpret
        type: file
        interpretation: default, shap
        """
        img = deepcopy(media_data.BASE64_IMAGE)
        image_input = gr.Image()
        iface = gr.Interface(
            lambda x: PIL.Image.open(x).rotate(90, expand=True),
            gr.Image(shape=(30, 10), type="file"),
            "image",
        )
        output = iface.process([img])[0]
        # Rotating a 30x10 image by 90 degrees with expand=True gives 10x30.
        self.assertEqual(
            gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
        )
        iface = gr.Interface(
            lambda x: np.sum(x), image_input, "number", interpretation="default"
        )
        scores = iface.interpret([img])[0]["interpretation"]
        self.assertEqual(
            scores, deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["scores"][0]
        )
        iface = gr.Interface(
            lambda x: np.sum(x), image_input, "label", interpretation="shap"
        )
        scores = iface.interpret([img])[0]["interpretation"]
        self.assertEqual(
            len(scores[0]),
            len(deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)["scores"][0][0]),
        )
        image_input = gr.Image(shape=(30, 10))
        iface = gr.Interface(
            lambda x: np.sum(x), image_input, "number", interpretation="default"
        )
        self.assertIsNotNone(iface.interpret([img]))

    def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def generate_noise(width, height):
            return np.random.randint(0, 256, (width, height, 3))

        iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
        self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64"))
class TestAudio(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context, deserialize
        Type: filepath, numpy, file
        """
        x_wav = deepcopy(media_data.BASE64_AUDIO)
        audio_input = gr.Audio()
        output = audio_input.preprocess(x_wav)
        self.assertEqual(output[0], 8000)
        self.assertEqual(output[1].shape, (8046,))
        self.assertEqual(
            audio_input.serialize("test/test_files/audio_sample.wav", True)["data"],
            x_wav["data"],
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
            self.assertEqual("audio_input/0.wav", to_save)
            to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
            self.assertEqual("audio_input/1.wav", to_save)
            restored = audio_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, "audio_input/1.wav")
        self.assertIsInstance(audio_input.generate_sample(), dict)
        audio_input = gr.Audio(label="Upload Your Audio")
        self.assertEqual(
            audio_input.get_template_context(),
            {
                "source": "upload",
                "name": "audio",
                "label": "Upload Your Audio",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        self.assertIsNone(audio_input.preprocess(None))
        # Mark the payload as an example with a crop window before re-preprocessing.
        x_wav["is_example"] = True
        x_wav["crop_min"], x_wav["crop_max"] = 1, 4
        self.assertIsNotNone(audio_input.preprocess(x_wav))
        with self.assertWarns(DeprecationWarning):
            audio_input = gr.Audio(type="file")
            audio_input.preprocess(x_wav)
            # WAV data is binary, so open the file in binary mode.
            with open("test/test_files/audio_sample.wav", "rb") as f:
                audio_input.serialize(f, False)
        audio_input = gr.Audio(type="filepath")
        self.assertIsInstance(audio_input.preprocess(x_wav), str)
        # Check preprocess and serialize separately: inside a single assertRaises
        # block the serialize call was unreachable once preprocess raised.
        with self.assertRaises(ValueError):
            audio_input = gr.Audio(type="unknown")
            audio_input.preprocess(x_wav)
        with self.assertRaises(ValueError):
            audio_input = gr.Audio(type="unknown")
            audio_input.serialize(x_wav, False)
        audio_input = gr.Audio(type="numpy")
        x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
        self.assertIsInstance(audio_input.serialize(x_wav, False), dict)

        # Output functionalities
        y_audio = gr.processing_utils.decode_base64_to_file(
            deepcopy(media_data.BASE64_AUDIO)["data"]
        )
        audio_output = gr.Audio(type="file")
        self.assertTrue(
            audio_output.postprocess(y_audio.name).startswith(
                "data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"
            )
        )
        self.assertEqual(
            audio_output.get_template_context(),
            {
                "name": "audio",
                "label": None,
                "source": "upload",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        self.assertTrue(
            audio_output.deserialize(
                deepcopy(media_data.BASE64_AUDIO)["data"]
            ).endswith(".wav")
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = audio_output.save_flagged(
                tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
            )
            self.assertEqual("audio_output/0.wav", to_save)
            to_save = audio_output.save_flagged(
                tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
            )
            self.assertEqual("audio_output/1.wav", to_save)

    def test_tokenize(self):
        """
        Tokenize, get_masked_inputs
        """
        x_wav = deepcopy(media_data.BASE64_AUDIO)
        audio_input = gr.Audio()
        tokens, _, _ = audio_input.tokenize(x_wav)
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(len(tokens), audio_input.interpretation_segments)
        # Keeping every segment (a mask of all 1s) should give back audio that is
        # very close to the original payload.
        x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
        similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
        self.assertGreater(similarity, 0.9)

    def test_in_interface(self):
        """
        Interface, process: reversing the audio twice should roughly restore it.
        """

        def reverse_audio(audio):
            sr, data = audio
            return (sr, np.flipud(data))

        iface = gr.Interface(reverse_audio, "audio", "audio")
        reversed_data = iface.process([deepcopy(media_data.BASE64_AUDIO)])[0]
        reversed_input = {"name": "fake_name", "data": reversed_data}
        self.assertTrue(reversed_data.startswith("data:audio/wav;base64,UklGRgA/"))
        self.assertTrue(
            iface.process([deepcopy(media_data.BASE64_AUDIO)])[0].startswith(
                "data:audio/wav;base64,UklGRgA/"
            )
        )
        reversed_reversed_data = iface.process([reversed_input])[0]
        similarity = SequenceMatcher(
            a=reversed_reversed_data, b=media_data.BASE64_AUDIO["data"]
        ).ratio()
        self.assertGreater(similarity, 0.99)

    def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def generate_noise(duration):
            return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)

        iface = gr.Interface(generate_noise, "slider", "audio")
        self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64"))
class TestFile(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context, default_value
        """
        x_file = deepcopy(media_data.BASE64_FILE)
        file_input = gr.File()
        output = file_input.preprocess(x_file)
        self.assertIsInstance(output, tempfile._TemporaryFileWrapper)
        self.assertEqual(
            file_input.serialize("test/test_files/sample_file.pdf", True),
            "test/test_files/sample_file.pdf",
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
            self.assertEqual("file_input/0", to_save)
            to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
            self.assertEqual("file_input/1", to_save)
            restored = file_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(restored, "file_input/1")
        self.assertIsInstance(file_input.generate_sample(), dict)
        file_input = gr.File(label="Upload Your File")
        self.assertEqual(
            file_input.get_template_context(),
            {
                "file_count": "single",
                "name": "file",
                "label": "Upload Your File",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        self.assertIsNone(file_input.preprocess(None))
        x_file["is_example"] = True
        self.assertIsNotNone(file_input.preprocess(x_file))
        file_input = gr.File("test/test_files/sample_file.pdf")
        self.assertEqual(
            file_input.get_template_context(),
            deepcopy(media_data.FILE_TEMPLATE_CONTEXT),
        )

    def test_in_interface_as_input(self):
        """
        Interface, process
        """
        x_file = deepcopy(media_data.BASE64_FILE)

        def get_size_of_file(file_obj):
            return os.path.getsize(file_obj.name)

        iface = gr.Interface(get_size_of_file, "file", "number")
        self.assertEqual(iface.process([[x_file]]), [10558])

    def test_as_component_as_output(self):
        """
        Interface, process, save_flagged
        """

        def write_file(content):
            with open("test.txt", "w") as f:
                f.write(content)
            return "test.txt"

        def remove_written_file():
            if os.path.exists("test.txt"):
                os.remove("test.txt")

        # write_file creates test.txt in the working directory; remove it when the
        # test ends so it does not leak into the repo or later tests.
        self.addCleanup(remove_written_file)
        iface = gr.Interface(write_file, "text", "file")
        self.assertDictEqual(
            iface.process(["hello world"])[0],
            {
                "name": "test.txt",
                "size": 11,
                "data": "data:text/plain;base64,aGVsbG8gd29ybGQ=",
            },
        )
        file_output = gr.File()
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = file_output.save_flagged(
                tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
            )
            self.assertEqual("file_output/0", to_save)
            to_save = file_output.save_flagged(
                tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
            )
            self.assertEqual("file_output/1", to_save)
class TestDataframe(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, serialize, save_flagged, restore_flagged, generate_sample,
        get_template_context and postprocess of gr.Dataframe.
        """
        x_data = [["Tim", 12, False], ["Jan", 24, True]]
        dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
        output = dataframe_input.preprocess(x_data)
        self.assertEqual(output["Age"][1], 24)
        self.assertEqual(output["Member"][0], False)
        self.assertEqual(dataframe_input.preprocess_example(x_data), x_data)
        self.assertEqual(dataframe_input.serialize(x_data, True), x_data)
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = dataframe_input.save_flagged(
                tmpdirname, "dataframe_input", x_data, None
            )
            self.assertEqual(json.dumps(x_data), to_save)
            restored = dataframe_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(x_data, restored)
        self.assertIsInstance(dataframe_input.generate_sample(), list)
        dataframe_input = gr.Dataframe(
            headers=["Name", "Age", "Member"], label="Dataframe Input"
        )
        self.assertEqual(
            dataframe_input.get_template_context(),
            {
                "headers": ["Name", "Age", "Member"],
                "datatype": "str",
                "row_count": 3,
                "col_count": 3,
                "col_width": None,
                "default_value": [
                    ["", "", ""],
                    ["", "", ""],
                    ["", "", ""],
                ],
                "name": "dataframe",
                "label": "Dataframe Input",
                "max_rows": 20,
                "max_cols": None,
                "overflow_row_behaviour": "paginate",
                "css": {},
                "interactive": None,
            },
        )
        dataframe_input = gr.Dataframe()
        output = dataframe_input.preprocess(x_data)
        self.assertEqual(output[1][1], 24)
        with self.assertRaises(ValueError):
            wrong_type = gr.Dataframe(type="unknown")
            wrong_type.preprocess(x_data)

        # Output functionalities
        dataframe_output = gr.Dataframe()
        output = dataframe_output.postprocess(np.zeros((2, 2)))
        self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
        output = dataframe_output.postprocess([[1, 3, 5]])
        self.assertDictEqual(output, {"data": [[1, 3, 5]]})
        output = dataframe_output.postprocess(
            pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
        )
        self.assertDictEqual(
            output,
            {
                "headers": ["num", "prime"],
                "data": [[2, True], [3, True], [4, False]],
            },
        )
        self.assertEqual(
            dataframe_output.get_template_context(),
            {
                "headers": None,
                "max_rows": 20,
                "max_cols": None,
                "overflow_row_behaviour": "paginate",
                "name": "dataframe",
                "label": None,
                "css": {},
                "datatype": "str",
                "row_count": 3,
                "col_count": 3,
                "col_width": None,
                "default_value": [
                    ["", "", ""],
                    ["", "", ""],
                    ["", "", ""],
                ],
                "interactive": None,
            },
        )
        with self.assertRaises(ValueError):
            wrong_type = gr.Dataframe(type="unknown")
            wrong_type.postprocess(0)
        with tempfile.TemporaryDirectory() as tmpdirname:
            # `output` is the postprocessed pandas DataFrame from above.
            to_save = dataframe_output.save_flagged(
                tmpdirname, "dataframe_output", output, None
            )
            self.assertEqual(
                to_save,
                json.dumps(
                    {
                        "headers": ["num", "prime"],
                        "data": [[2, True], [3, True], [4, False]],
                    }
                ),
            )
            self.assertEqual(
                dataframe_output.restore_flagged(tmpdirname, to_save, None),
                {
                    "headers": ["num", "prime"],
                    "data": [[2, True], [3, True], [4, False]],
                },
            )

    def test_in_interface_as_input(self):
        """
        Interface, process with the "numpy" and "list" dataframe input shortcuts.
        """
        x_data = [[1, 2, 3], [4, 5, 6]]
        iface = gr.Interface(np.max, "numpy", "number")
        self.assertEqual(iface.process([x_data]), [6])
        x_data = [["Tim"], ["Jon"], ["Sal"]]

        def get_last(my_list):
            return my_list[-1]

        iface = gr.Interface(get_last, "list", "text")
        self.assertEqual(iface.process([x_data]), ["Sal"])

    def test_in_interface_as_output(self):
        """
        Interface, process with a "numpy" dataframe output.
        """

        # Element-wise parity test: True where the value is even.
        def is_even(array):
            return array % 2 == 0

        iface = gr.Interface(is_even, "numpy", "numpy")
        self.assertEqual(
            iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]}
        )
2022-03-30 04:23:30 +08:00
class TestVideo(unittest.TestCase):
    def test_component_functions(self):
        """
        Exercise gr.Video on both sides: preprocess, serialize, save_flagged,
        restore_flagged, generate_sample and get_template_context as an input,
        then postprocess, deserialize and save_flagged as an output.
        """
        video_payload = deepcopy(media_data.BASE64_VIDEO)

        # Input side.
        uploader = gr.Video()
        self.assertIsInstance(uploader.preprocess(video_payload), str)
        with tempfile.TemporaryDirectory() as flag_dir:
            # Successive flags are written under increasing indices.
            for expected_path in ("video_input/0.mp4", "video_input/1.mp4"):
                saved_path = uploader.save_flagged(
                    flag_dir, "video_input", video_payload, None
                )
                self.assertEqual(expected_path, saved_path)
            restored_path = uploader.restore_flagged(flag_dir, saved_path, None)
            self.assertEqual(restored_path, "video_input/1.mp4")
        self.assertIsInstance(uploader.generate_sample(), dict)

        labelled = gr.Video(label="Upload Your Video")
        expected_context = {
            "source": "upload",
            "name": "video",
            "label": "Upload Your Video",
            "css": {},
            "default_value": None,
            "interactive": None,
        }
        self.assertEqual(labelled.get_template_context(), expected_context)
        self.assertIsNone(labelled.preprocess(None))
        video_payload["is_example"] = True
        self.assertIsNotNone(labelled.preprocess(video_payload))

        avi_video = gr.Video(type="avi")
        converted_path = avi_video.preprocess(video_payload)
        self.assertEqual(converted_path[-3:], "avi")
        with self.assertRaises(NotImplementedError):
            avi_video.serialize(video_payload, True)

        # Output side.
        player = gr.Video()
        encoded = player.postprocess("test/test_files/video_sample.mp4")
        self.assertTrue(encoded["data"].startswith("data:video/mp4;base64,"))
        decoded_path = player.deserialize(deepcopy(media_data.BASE64_VIDEO)["data"])
        self.assertTrue(decoded_path.endswith(".mp4"))
        with tempfile.TemporaryDirectory() as flag_dir:
            for expected_path in ("video_output/0.mp4", "video_output/1.mp4"):
                saved_path = player.save_flagged(
                    flag_dir, "video_output", deepcopy(media_data.BASE64_VIDEO), None
                )
                self.assertEqual(expected_path, saved_path)

    def test_in_interface(self):
        """
        Interface, process: an identity function passes the video data through.
        """
        x_video = deepcopy(media_data.BASE64_VIDEO)
        iface = gr.Interface(lambda x: x, "video", "playable_video")
        first_output = iface.process([x_video])[0]
        self.assertEqual(first_output["data"], x_video["data"])
2022-04-05 17:54:17 +08:00
2022-03-30 04:23:30 +08:00
class TestTimeseries(unittest.TestCase):
    def test_component_functions(self):
        """
        Preprocess, postprocess, save_flagged, restore_flagged, generate_sample,
        get_template_context of gr.Timeseries.
        """
        timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
        x_timeseries = {
            "data": [[1] + [2] * len(timeseries_input.y)] * 4,
            "headers": [timeseries_input.x] + timeseries_input.y,
        }
        output = timeseries_input.preprocess(x_timeseries)
        self.assertIsInstance(output, pd.core.frame.DataFrame)
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = timeseries_input.save_flagged(
                tmpdirname, "timeseries_input", x_timeseries, None
            )
            self.assertEqual(json.dumps(x_timeseries), to_save)
            restored = timeseries_input.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(x_timeseries, restored)
        self.assertIsInstance(timeseries_input.generate_sample(), dict)
        # A single string for `y` is reported back as a one-element list.
        timeseries_input = gr.Timeseries(
            x="time", y="retail", label="Upload Your Timeseries"
        )
        self.assertEqual(
            timeseries_input.get_template_context(),
            {
                "x": "time",
                "y": ["retail"],
                "name": "timeseries",
                "label": "Upload Your Timeseries",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        self.assertIsNone(timeseries_input.preprocess(None))
        x_timeseries["range"] = (0, 1)
        self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))

        # Output functionalities
        timeseries_output = gr.Timeseries(label="Disease")
        self.assertEqual(
            timeseries_output.get_template_context(),
            {
                "x": None,
                "y": None,
                "name": "timeseries",
                "label": "Disease",
                "css": {},
                "default_value": None,
                "interactive": None,
            },
        )
        data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
        df = pd.DataFrame(data)
        self.assertEqual(
            timeseries_output.postprocess(df),
            {
                "headers": ["Name", "Age"],
                "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
            },
        )
        timeseries_output = gr.Timeseries(y="Age", label="Disease")
        output = timeseries_output.postprocess(df)
        self.assertEqual(
            output,
            {
                "headers": ["Name", "Age"],
                "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
            },
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = timeseries_output.save_flagged(
                tmpdirname, "timeseries_output", output, None
            )
            self.assertEqual(
                to_save,
                '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
                '["jack", 18]]}',
            )
            self.assertEqual(
                timeseries_output.restore_flagged(tmpdirname, to_save, None),
                {
                    "headers": ["Name", "Age"],
                    "data": [
                        ["Tom", 20],
                        ["nick", 21],
                        ["krish", 19],
                        ["jack", 18],
                    ],
                },
            )

    def test_in_interface_as_input(self):
        """
        Interface, process with gr.Timeseries as the input component.
        """
        timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
        x_timeseries = {
            "data": [[1] + [2] * len(timeseries_input.y)] * 4,
            "headers": [timeseries_input.x] + timeseries_input.y,
        }
        iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
        self.assertEqual(
            iface.process([x_timeseries]),
            [
                {
                    "headers": ["time", "retail", "food", "other"],
                    "data": [
                        [1, 2, 2, 2],
                        [1, 2, 2, 2],
                        [1, 2, 2, 2],
                        [1, 2, 2, 2],
                    ],
                }
            ],
        )

    def test_in_interface_as_output(self):
        """
        Interface, process with gr.Timeseries as the output component.
        """
        timeseries_output = gr.Timeseries(x="time", y=["retail", "food", "other"])
        iface = gr.Interface(lambda x: x, "dataframe", timeseries_output)
        df = pd.DataFrame(
            {
                "time": [1, 2, 3, 4],
                "retail": [1, 2, 3, 2],
                "food": [1, 2, 3, 2],
                "other": [1, 2, 4, 2],
            }
        )
        self.assertEqual(
            iface.process([df]),
            [
                {
                    "headers": ["time", "retail", "food", "other"],
                    "data": [
                        [1, 1, 1, 1],
                        [2, 2, 2, 2],
                        [3, 3, 3, 4],
                        [4, 2, 2, 2],
                    ],
                }
            ],
        )
2022-04-07 06:40:28 +08:00
2022-03-30 04:23:30 +08:00
class TestNames(unittest.TestCase):
    # Guards `components.get_component_instance()`: instantiating a component
    # from its name must be unambiguous, so no two direct Component subclasses
    # may share a name once case is ignored.
    def test_no_duplicate_uncased_names(self):
        direct_children = gr.components.Component.__subclasses__()
        lowered_names = {child.__name__.lower() for child in direct_children}
        self.assertEqual(len(direct_children), len(lowered_names))
class TestLabel(unittest.TestCase):
    def test_component_functions(self):
        """
        postprocess, deserialize, save_flagged, restore_flagged and
        get_template_context of gr.Label, for both a plain string label and a
        class -> confidence mapping.
        """
        # A bare string becomes {"label": <string>}.
        plain = "happy"
        single_label = gr.Label()
        processed = single_label.postprocess(plain)
        self.assertDictEqual(processed, {"label": "happy"})
        self.assertEqual(single_label.deserialize(plain), plain)
        self.assertEqual(single_label.deserialize(processed), plain)
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = single_label.save_flagged(
                flag_dir, "label_output", processed, None
            )
            self.assertEqual(flagged, plain)

        # A mapping is sorted into confidences, optionally truncated.
        scores = {3: 0.7, 1: 0.2, 0: 0.1}
        all_classes = gr.Label()
        processed = all_classes.postprocess(scores)
        self.assertDictEqual(
            processed,
            {
                "label": 3,
                "confidences": [
                    {"label": 3, "confidence": 0.7},
                    {"label": 1, "confidence": 0.2},
                    {"label": 0, "confidence": 0.1},
                ],
            },
        )
        top_two = gr.Label(num_top_classes=2)
        processed = top_two.postprocess(scores)
        self.assertDictEqual(
            processed,
            {
                "label": 3,
                "confidences": [
                    {"label": 3, "confidence": 0.7},
                    {"label": 1, "confidence": 0.2},
                ],
            },
        )
        with self.assertRaises(ValueError):
            top_two.postprocess([1, 2, 3])
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = top_two.save_flagged(flag_dir, "label_output", processed, None)
            self.assertEqual(flagged, '{"3": 0.7, "1": 0.2}')
            # Keys come back as strings after the JSON round trip.
            expected_restored = {
                "label": "3",
                "confidences": [
                    {"label": "3", "confidence": 0.7},
                    {"label": "1", "confidence": 0.2},
                ],
            }
            self.assertEqual(
                top_two.restore_flagged(flag_dir, flagged, None), expected_restored
            )
        expected_context = {
            "name": "label",
            "label": None,
            "css": {},
            "interactive": None,
        }
        self.assertEqual(top_two.get_template_context(), expected_context)

    def test_in_interface(self):
        """
        Interface, process: an image -> label Interface returning per-channel
        mean proportions.
        """
        x_img = deepcopy(media_data.BASE64_IMAGE)

        def rgb_distribution(img):
            channel_means = np.mean(img, axis=(0, 1))
            proportions = np.round(channel_means / np.sum(channel_means), decimals=2)
            return dict(zip(("red", "green", "blue"), proportions))

        iface = gr.Interface(rgb_distribution, "image", "label")
        output = iface.process([x_img])[0]
        self.assertDictEqual(
            output,
            {
                "label": "red",
                "confidences": [
                    {"label": "red", "confidence": 0.44},
                    {"label": "green", "confidence": 0.28},
                    {"label": "blue", "confidence": 0.28},
                ],
            },
        )
class TestHighlightedText(unittest.TestCase):
    def test_component_functions(self):
        """
        get_template_context, save_flagged and restore_flagged of
        gr.HighlightedText.
        """
        colors = {"pos": "green", "neg": "red"}
        ht_output = gr.HighlightedText(color_map=colors)
        expected_context = {
            "color_map": {"pos": "green", "neg": "red"},
            "name": "highlightedtext",
            "label": None,
            "show_legend": False,
            "css": {},
            "default_value": "",
            "interactive": None,
        }
        self.assertEqual(ht_output.get_template_context(), expected_context)
        highlights = {"pos": "Hello", "neg": "World"}
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = ht_output.save_flagged(flag_dir, "ht_output", highlights, None)
            self.assertEqual(flagged, '{"pos": "Hello", "neg": "World"}')
            restored = ht_output.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(restored, {"pos": "Hello", "neg": "World"})

    def test_in_interface(self):
        """
        Interface, process: runs of vowels and non-vowels are emitted as
        (text, category) pairs.
        """

        def highlight_vowels(sentence):
            spans = []
            current = ""
            current_kind = None
            for ch in sentence:
                kind = "vowel" if ch in "aeiou" else "non"
                # Close the running span whenever the category flips.
                if current_kind is not None and kind != current_kind:
                    spans.append((current, current_kind))
                    current = ""
                current_kind = kind
                current += ch
            spans.append((current, current_kind))
            return spans

        iface = gr.Interface(highlight_vowels, "text", "highlight")
        self.assertListEqual(
            iface.process(["Helloooo"])[0],
            [("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
        )
class TestJSON(unittest.TestCase):
    def test_component_functions(self):
        """
        Postprocess, save_flagged, restore_flagged, get_template_context of gr.JSON.
        """
        js_output = gr.JSON()
        # assertEqual, not assertTrue: assertTrue's second argument is only the
        # failure message, so it would pass for any truthy return value.
        self.assertEqual(
            js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
        )
        js = {"pos": "Hello", "neg": "World"}
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = js_output.save_flagged(tmpdirname, "js_output", js, None)
            self.assertEqual(to_save, '{"pos": "Hello", "neg": "World"}')
            self.assertEqual(
                js_output.restore_flagged(tmpdirname, to_save, None),
                {"pos": "Hello", "neg": "World"},
            )
        self.assertEqual(
            js_output.get_template_context(),
            {
                "css": {},
                "default_value": '""',
                "label": None,
                "name": "json",
                "interactive": None,
            },
        )

    def test_in_interface(self):
        """
        Interface, process: a dataframe -> json Interface returning a dict.
        """

        def get_avg_age_per_gender(data):
            # Average only the numeric "age" column; averaging the whole frame
            # would also include the string "gender" column.
            return {
                gender: int(data.loc[data["gender"] == gender, "age"].mean())
                for gender in ("M", "F", "O")
            }

        iface = gr.Interface(
            get_avg_age_per_gender,
            gr.inputs.Dataframe(headers=["gender", "age"]),
            "json",
        )
        y_data = [
            ["M", 30],
            ["F", 20],
            ["M", 40],
            ["O", 20],
            ["F", 30],
        ]
        self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
2022-03-30 04:23:30 +08:00
class TestHTML(unittest.TestCase):
    def test_component_functions(self):
        """
        get_template_context reports the default value and label the HTML
        component was built with.
        """
        html_component = gr.components.HTML("#Welcome onboard", label="HTML Input")
        expected_context = {
            "css": {},
            "default_value": "#Welcome onboard",
            "label": "HTML Input",
            "name": "html",
            "interactive": None,
        }
        self.assertEqual(expected_context, html_component.get_template_context())

    def test_in_interface(self):
        """
        Interface, process: a text -> html Interface returns the markup built by
        the wrapped function.
        """

        def bold_text(text):
            return "".join(("<strong>", text, "</strong>"))

        iface = gr.Interface(bold_text, "text", "html")
        rendered = iface.process(["test"])[0]
        self.assertEqual(rendered, "<strong>test</strong>")
2022-03-30 04:23:30 +08:00
class TestCarousel(unittest.TestCase):
    def test_component_functions(self):
        """
        Postprocess, get_template_context, save_flagged, restore_flagged of
        gr.Carousel.
        """
        carousel_output = gr.Carousel(
            components=[gr.Textbox(), gr.Image()], label="Disease"
        )
        output = carousel_output.postprocess(
            [
                ["Hello World", "test/test_files/bus.png"],
                ["Bye World", "test/test_files/bus.png"],
            ]
        )
        self.assertEqual(
            output,
            [
                ["Hello World", deepcopy(media_data.BASE64_IMAGE)],
                ["Bye World", deepcopy(media_data.BASE64_IMAGE)],
            ],
        )
        carousel_output = gr.Carousel(components=gr.Textbox(), label="Disease")
        output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
        self.assertEqual(output, [["Hello World"], ["Bye World"]])
        self.assertEqual(
            carousel_output.get_template_context(),
            {
                "components": [
                    {
                        "name": "textbox",
                        "label": None,
                        "default_value": "",
                        "lines": 1,
                        "css": {},
                        "placeholder": None,
                        "interactive": None,
                    }
                ],
                "name": "carousel",
                "label": "Disease",
                "css": {},
                "interactive": None,
            },
        )
        # With a single component, a flat list is wrapped one item per row.
        output = carousel_output.postprocess(["Hello World", "Bye World"])
        self.assertEqual(output, [["Hello World"], ["Bye World"]])
        with self.assertRaises(ValueError):
            carousel_output.postprocess("Hello World!")
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = carousel_output.save_flagged(
                tmpdirname, "carousel_output", output, None
            )
            self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
            restored = carousel_output.restore_flagged(tmpdirname, to_save, None)
            self.assertEqual(output, restored)

    def test_in_interface(self):
        """
        Interface, process: each carousel item is a [caption, image] pair, and
        the image is checked to come back as a base64 PNG data URI.
        """
        carousel_output = gr.Carousel(
            components=[gr.Textbox(), gr.Image()], label="Disease"
        )
        expected_captions = ["Red", "Green", "Blue"]

        def report(img):
            results = []
            for i, mode in enumerate(expected_captions):
                color_filter = np.array([0, 0, 0])
                color_filter[i] = 1
                results.append([mode, img * color_filter])
            return results

        iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
        result = iface.process([deepcopy(media_data.BASE64_IMAGE)])
        carousel_items = result[0]
        self.assertEqual(len(carousel_items), len(expected_captions))
        for (caption, image_data), expected_caption in zip(
            carousel_items, expected_captions
        ):
            with self.subTest(caption=expected_caption):
                self.assertEqual(caption, expected_caption)
                # Assert the data-URI prefix so an empty or non-image payload fails.
                self.assertTrue(image_data.startswith("data:image/png;base64,"))
2022-04-05 17:54:17 +08:00
if __name__ == "__main__":
    # Allow running this module directly as a script, in addition to being
    # collected by a test runner.
    unittest.main()