From a1c391668a95539982c82b2c6c936f6590a5960e Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 6 Jul 2022 11:23:35 -0700 Subject: [PATCH] Exposing examples as a component for Blocks (#1639) * examples as component * renamed examples * simplify internal logic * fix tests * cleanup * fixed parallel and series * cleaning up examples * examples * formatting * fixes * added unique ids * added demo * formatting * fixed test_examples * fixed test_interfaces * fixed tests * removed test from now * raise ValueError for bad parameter values * fixing series * fixed series * formatting * speed up by preprocessing examples * fixed parameter validation logic --- demo/blocks_inputs/lion.jpg | Bin 0 -> 18489 bytes demo/blocks_inputs/run.py | 37 +++- gradio/__init__.py | 1 + gradio/examples.py | 200 ++++++++++++++++++ gradio/flagging.py | 9 +- gradio/interface.py | 168 ++++----------- gradio/mix.py | 80 +++---- gradio/process_examples.py | 68 ------ ...t_process_examples.py => test_examples.py} | 8 +- test/test_external.py | 30 +-- test/test_interfaces.py | 4 +- 11 files changed, 327 insertions(+), 278 deletions(-) create mode 100644 demo/blocks_inputs/lion.jpg create mode 100644 gradio/examples.py delete mode 100644 gradio/process_examples.py rename test/{test_process_examples.py => test_examples.py} (74%) diff --git a/demo/blocks_inputs/lion.jpg b/demo/blocks_inputs/lion.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9bf9f5d0816d6201b4862088dc74476249a6a70 GIT binary patch literal 18489 zcmb69WmsHI&^3(C;0}YkySux4a0~A45D4xJ4DPPMgL@#jdmy+oNPwV$;GEpg^STkD+t>Q>2xkl00M^p! zm1nILt@8=M)VFB1fJz0c8Df`q3;#812+#p{M-7C#;+8PZ}|t(B6U@B@9BlWU>t0IoIUi>z(Ld+r|zCOB5FZ<%osw zhC&xQv<7}d?gH8D?&^+i%_{Al`gj}F z(OT}N#j2FKJwsCEEX|G)B|)xZeOL#>pZ-PC86?lD#edPStUh{LVMrj>5y#=+mHknj z%nF$MgnFqEryTg)5J!ti)-n5-*iuv;r^jZqt(lU?-f{k^)48cQeA0e2dgw!i{8@bK z5TX^2=;AayU8zUHbxU7!n)`3O0a9&1>*SBY6YjR`)TEy=?YeBefl;Mx{oH1)3IP>E z@ao5$bZQ)kISVlrGd^B@mqra*8oSWr0e%KiNwvW7sHao^n`{z$L zZAJj7IH@C%tr{wwqK8%Bdc zgAIY4jCZdqhE7^KVrdUFX-$}#v<<}#r#-8s6(ZZUJ*Ve3q*#s|qlt4THAWZNPIY@Z4@R(YPj-|^Y-?d zFTPIGsVUL8H!xxytCZ}vFkigd-~@bTv%4ML-vPk(!QmDo9#dB3)d?XI2l3e_^ke3L zgz(iETP^wia1Mh%Jnpu<>vIvxlRl5lhTa|3Emq8!1HJ0fHxmCc6f+h8=05|%!Td{O z|D~{h843;ng2%?8;=;nEmZIUt<3XULmHwBrkp87IAUw=dQpJ_U6Kv=~ptB3@>EgtJ zTMr^lt3aFc3Ae?W>4C9Dlt;8tyo1AM#Y!4y`9ZIcJ3@KPszIML#eNlgNb2L_4U7SG z*NZNz^VdG2S{s`Q0&|8o;u+hfu?~tPI`X*vr(wYl#X)C$cdE;r1Gy7PaslSn%SGhs zzgeBLE7!w4BSR+EBAaLwaP_{0@*CLuF5~=aYDqIwxk2Dqyw4swEt2L$gJpYGKtANc z7)|#q8+lxhOBLgO#X7ob zXL|a$4W*Ma6etWqZd4H+RpKThX=>KBC^Y#-MBjjt*-c!#m@evx_8eR4-4g~oPb@2_Vx27XSv`)ME*q9g`gxueATb1fn?&B1nd2gz*IA~+=K%@vu-c<5oWC~ zhvJj~@&e^T6BQ1^xD_GYpMi*YE$wcn2+r__Hi-%cG_oq0v>7L|EF*Ne3q6?-t{bUJ z(LhiAjC%cw1XgYZTmDKmkzhrCLR}lM*q72SJRBNamRQcQ#Bcb_F}O8yX!)rKRXh`g z#<|IG$Kf_)G1I7Nvt(rM8J#dL^`D5b5gMFq2fL8H^mj2Sn|T} zK6GthHURB)txDOGDy%+^B@ttNpsXA`TgrDi?tz0u2U%;QBg>mKm4{>U7V?D?+_MdO zejSaf<|rKMW;aCmyen!*$ip^1NM}p*Yp(#!d6*Q&&J4G+V4#+sMV_e zG)U?lkTj0qqRwj-v@P1e=y99%`QbdI>gX!)eth<}0(WqzjpD>xHU3oks*mw2v&`6p zC6QcQv(}#Nl^A{+`}&oPp@@Q!{OWzys#9O0I?bz%he97@ezriBoz$3MUwvNg;(+C?4`JjeSUvpMGnnbiXd> za+THmJ~l?4dtx3$?X<99`<2Qf9;$<9_ClmyZ<76S#es2M4Vv;Mf2fh^=L7_Owy%?)xEYOBq_uNr(B2m=)V=w1F-3>W3Sqy@$R_KQI1fb^i8m z8TY+Xi%$*yE)9G<-d|nk zXe(k1kq~WfIx&c>qrT&f z%H$VDF7y=r&j<@h`l)c5tTYc8SoPL57 z8IG>=9lZnSG@BLA#>qvP;qaj#?XwhZIN*wCQ2HVYR2k8P<|(WTx}s2RS)=<2d6_)vHlNNuE)pYkPSU zFH5PP+-s-br7`YOX3k>68rJQ{OdXKJ9a9yO5EDKvvBmi|7o06yAdVV$|X`-yA@=d_7R9*lUK~I z<*gDn+i9vNpjsP>H(~!K?}Pe6cz-3x&^{sC&veS+b-xHUKK`P+nm7o6(H_8Z-0<>Owqhu<_gU5r`=xX4`ceqrPC!(ZZ7>!IUi z1NH%lyEigZzqW5hgMR94*m4d)hh-B#eB+AAsBnB1{BreC_#8HPhsLTA=#y5XMnlSK zjXkMUPx6};tn0=aR+~ak>me#fKr^QedHF4>A+=;^HmpLDQkTsA(LMd+~(CK(xZ)m{w@bXr2vFks|&Mnb-+6 zSOmf8hodWokKrT9C?R|U-G?s!$_5S%GrziCG)Ss1T=`cpg{7JFND2$9(r)?sM}CYS zLHn4{qWW5Sv}Sv0OBh_@G0W;jh0gKAMVC(|S-!qpC@ku#NqSZ!tozem)34@l3+TSbBc{j1 z&)`*UZG%oJ^be=efc_V!{U4C_zd#LyMTJcX%Y`FF&8=ziKLG6?+mryFLp?c4X{+y> zXIZnrDc9{M4@zIKr$Ai3(d^gvpEI~U)=t2T6~T4MPx7KU97p1-MGNO8V>bCo&M7u~ zG@~EnoGYaChJT~pOc|v6_$$gSvM;bvu&{oXMP%hxO#j~PIlo<(R5np3^Yl^3{v%c# zZ7o(*5s{cM^`*g${9HmDZ&a(YLYy#*xZ(-Kxmj!enkY488Y*OoDnn$IJzh4&T9RQE zA%}JUStvDNtSlqHDi1#>MHj4KsV*X}ISn;+)=_y8%(vI%_~M=Dgp@*r=`RQRCF2;A z^Lbb^h=tWEvGFRV9*kM`%MY4llVbVHAmyKfX#dAS|FP7aWwY<;hql!!rElgu<7!yOB_q=-aNF zmH~p=n>ABinbL_@CH;;6!JLtB_YR0E=+jF4atFqY!YkB>p8Xt$J7m8n5BHVO`m`%)+!8op7honwFVKA)BdxlK{+*yr zgOh1k20l-(1ts)M!ddzV7kr4J*~VU5g0P?+BpSq*@_fK~&@Dh3lhb27vknx)1`(8L zNnn-9*)}!P6(kPozaen>$(AZ`U=C3iJeVmHnR2jzGOv2-Bv|nF3t2FK{0H`humC_< z7&tgsctp7W_0m6gfvC7(v86Pzs4YBjxWnM2Q%ag}DQP&hEC+Zb7p|?YJj3_?5B5ip zfO(l;)FH)0lGs|9%2&`0+|gapIbG5OGGYZC;=B%vEHT=2(-h6upPg2J1#!v!P?7<4 z5Hp`jw2R>iJK3B5lF-e`WaQdkzk+|Ww$N~3X!g=LI&?7Xcw0K;lGE^f1uf$g&(}RM z^DdTN?I@@(t2f9L!`%xKM_H0FpO=Ox6We+nj z;AmzfqiI|o^`i)K6(zMx^Y`LZJ7GOTotWh)>0^-iB9B3h7v%yo-GYRrNTrxgmc9m7 zT%nB{m$u2xp#fJshdxslgEXG-Lfo_klHP(wJN}$uE>5UP3}8Q-;I~@jn|{n0wWi z6O;XXb!rR4n194NWfLse^DA4+c(YG7V}qQN6x=+L(Yoro$%NvX$jH3GUz z61vHY2WIdet!>lEHvPSv5!sJ7l)CT>dSQyT=H3SCgQR7-eCp5a?DY6lSog^? z#^2`>O?U`6H9xiPi`DsI2pV`Zl-_(p51p3`g>mBStNZ|v@|9*X|K$pa+iV)-T z?5z1!P=R#DMC00ilgIiz^Z7`(b^!r)I+S2bZL7<8c+LipQa|}Pru|0H{BT%u9&Gyu zmB;g&p&@Ftmj!qJv1ZFeBgzy~T%`Jr>=-JpWzEMLW}oL&oPJ)eN1`nci=Tpv0t zwP3@9P2~cZ(plR1r!Bs0bVBE(Q^6pd0&wh$s)Cl1>@&V-$XF~b)qOs{MNDgRjI1O* ze1}~9?G{2Br)&Oom0TN}d*TQ}8iEGmR^#igqmP7qsLm{e&Gs@h8}a#z=o zBfg;_7+}Iz$7i9rz_&}Tn{BUQJyZc6SfqnneThlz@hm28PS6pWkE1Q!m~?VxVf>>e z&-Pk5VIZ_&L^^-Z^U^%mWsbpTN4!7xL>WyHOD$%JLp6PNy zsK!qn)by&lG81vsN?pd{C~XVW+(|zI^{dtWg}R2ZQMi+Nh-hk=YsW{&c2VFSRD)1I zYw09^R8WF!Obz!yNHGyt>rSOgl)0HgV8Qy)$BcISWJGN~rQf#Q;7;vTlg=-D=qifl zj&d45y3qiBiM>`Njg)m?VE}1&#Xi~tr%lgR0yXFQMaeComg?&eUGO(BBH=Dro7YxFlQ92Yh#%x{ z!(EMI)FL(#@PUE+&eD)}CA=--DLsvLte*NQG%`;Go6YwPUPE;=qhkWK;vK+Z)qrB6 z*dEsm^F{6_m7;)JiifusBF7N?pszL$zuiXNom6 zdr)Q@r+76wi)YsQ!$wwevu@Fco22H*JE?|-7gpUw2W`<@N@qm8L}!AZau7Hm17s+@ z_Bh|bAWCS{KhG2Wx>~mPlHt-*usx~EHtyrnLAx+Hzn z4CvTzPl$`=rPe*cLjQXb#y9LUAvg9#Pa2p2Vf5l6-s?>l`^|8#{ivCGXy;Pti>;NfqiW^4Sz;Q=UFtTIQR2ssPa52bB}VeQ?|^jT_2rWO z1+>e>&lzeOgl?1Mk{CQKLU6Wn2U!B(U4?ZD1-FDu*2$^qM*0C4K31=~6@eEnQ_ z9~|@C-@raOk+oo)#|`bD-ddt91R6jpm_>j?=?nE*M$BF`L_HdB-_b&i*0Irp>#>Kv z-|Z%mW;&G-(d?>;b6pX`&$~dMj|s07!YVp;r#u%YA26)(nex2oX-#OITIE)2orAp@ zwwfwnp%s^;@QG$$@3W>2zhUkN7yPX_p`23i)yh7M}Bf=uZVO*1eB z)<>G85gIZP{fy9$9g-DZSuuyiZMDw{bP=XAiJ z+m^73zPSLVQcf>yjeGNuQ&51Hm|amjZmn)nt}v_gRtooPoSQ7WQ?pl|+5l3&)d!6( zOL$txJjZPmn#T{2-X%nvfsH?zT_h);|r2cCbk|m z+$TUAC-6D@S3`9n8P&lVN;ei3@EuFuD>b3|ic$9`To#%JX3zYM(^}fUvNoXaH5@%P z=p3lF!_u@n&BERbd`$t9N1!!DsqqfH1w`Gw)|8mGFZy9RGScO$rR+NLuH{wZQE+kz zI`#rOVo%kYDD~;IrwL3wRk!F1Fzn7>DsrRH8JOm)n|tY;D^5b4zE|fjuGn16l(QR> z_?Z0zs0{z9(Lh*Oxc^_(4EPV9N@-engr#7WG+hlW>|Im-UnyFGDh$a}HKwG_G%}n8 z4XGTWo8YvfBf=nPsh86zP1;K)UQm9ujua#;#;XxY{niDJx5S*a{M5)yQlNA!c~uEM zwd106-W!t%R;KZEn0v8SSM=$V62W0rGP|=PX;8YO{8JXl_=)pnG!Xq6V0#9J{X)s>dINyVH|)8mxx zH8yYNJK6KO^pDXW`(h`SeO7NCI98YCCL<~K=541#9<*MkiH$4{!o}&BQbs~c!2-KS z`%9LBdr=@uRam?S*(qMyleF*V)*>JCyc1cS{fyna8a z&wNn%dlPgROvSiH+fRz-Q+yLXRh(+jo1F|Il#=}6r~J~1O_fURSXs5>nCa{Z{4&u& z77ITkKv}@8#B?qq$T6PglQUtk4|m}X-!BxyquXmV6eQiIto@B706!`ONH2K7qHozu zx65e@aMJ{ChyReR)y{Uu53itNv~U#``Ynt(!7iUoi_Cq-V5}D*i=Rkgor9PRDZsK; zO|o#Y*~R;5#fOv|YpqL`S29?jh+-5=r0gU=ibm2%Z+aD|V6F%b&HjbZc4H3TnK}vh z)X$r(NtadZV-+lZI1W9;P!AhU(YgvA#pbCpBm_-qFca8^n7$!mv{wubo+xq2DFpRK zswb2w66(6Cpp81PTDK>1|5M>2%_88TsK7F&b2k&hF{X3H@RCGJCJouxZ0|834pQ;_ zXS9=W6wOGy&A(Z{S5I&(7+HBk%c#!n>1RR=dF5d(-oDCq<+{->N@lN;{|$)J9**^! zlE)Pgbsk5FYMN`psW#Ykl)gah1)i}Lt2wyFNqV2q$km2(E+=wc5mdDL{dog^w3WVa zplA}iAvr*pB2oHb{giwfN=5!1F;)k2xMZg-B|d`SQ3BR2?17sOf(3o-zQ7k@6z;*U z|A6uRAoz|1%fAT7wk&|>o?fLoo&r?bonJKJW41>9^F!emZVrgQ&6v5hH@cylJ z-BjTKlXrBz-25XtB9E5-r`EA>0ucNw7{c!KEE1oK?437evD8hv(d^HX3;m#)`W^k+ zRaxhU;t6;c9BhR|--|dE&828@h$c0)U|3=u6!Ve+yT?n(mYfWJkZ<%0zx)%M&+$9X z6OXP+evF~C--I?}+B%GC;~fAnLjeo%HU3Z^=BMk7r`BMEOQ0JkCXWk^q%0p9yyN;a z&j%xXHpxQB>vrOzGO-9R{kKuxN{})Ki(Rt{DKRv1-3Fu3hKV|fE>jdqZF!^<-kU(n zh@kD;W*~GgZ-B@}dIao}Mh;|>X6VM>gYDm;(wWnoVy@N+`9M9jsD3;&S=?+`W3|yOyd7yFYx^KibCU*dXcwYs!xC$` zIrJiq+$NLO^;IT&{WC&6Y$Go^T|@)g7NtS?1l;Z9))`cm3~dV*wHOw!tV&n;lvZ8P z>St9i*)^~axOOBC-K(^q*Ji<9fLv9d=ID6%4PWFkwpnI|lr z&a6VEJJ7b=6Z-#j?Nb4?9pkEAx9v)Vr`Q;KhAOEx9o>Ra->VWSi~je1>}P z#{3+6V9wsO7V5ZQ0Te_lL?L5c>?4j>5J13LyNHtZ?EZvUOgt-0Dtv97Bi@oHB=kTP z+c4yr@}-Y%JKW>TdU86$?1YzpjG|9;D)9i`yVJe4T;>{9~R~Uv184i4ntdQA!*(!;(ayg07I=qeNewfdnq1$SVVUpD7 zE`Ibti^kp3(TqfHwl-M@rLrhGBrx}3B+iD?iPqb2y48(%T2twyZw}?*IW_S^6{}pd zZ2SwzrJtKvxz&^Z3y_~5F0mwk(n_j^RE6q#uN!@szU)jm+gcXEg}x2miDYamYsuIS zZElfGB*r#PsF^llZa(M=VR*)Yw#jI|X|vLU@6!d|MkE>Ecf@QdXhMrjDpjFuG@13L zOPAo+)<0yTc*W`R(bx|!ao9A%?c}AVhye!b3iYL`JdY7(m*k2c$O(fOsr9<3QF!3D zSFbwg17v%N$E1wcklZt6`$cE}lDlKdeTXn2L$ov*or_8iA6r4{%wY8tJWzrfZk|}G zdP*RAy;3Q1t@8XTS;9(pE=5l+fhTWh7UWFXl7)w1$foLwbRCzUY^+eF!q%7^(piV5 zN_|c?AkVGko^TJ|qh}JWE)hir=QY~>+)U+eNYN1px9O9Ve+s>tYw%DSZq77ab=>F}}oc%HeoM0Uk(F*&o=9OXl(@giRsB#R%I+ zdf&1cABn(0TIB-x3~{Sd!mStIHHDk~v}?asLM;BlV^1q`+#YFCAabn^hu0fRKNpiO zOCy~|Eed{G)whwCZ*sBVNWzj=>m4v3&&Z>(d_`1O;%FgUx)8-C0;SSM2}UAh64aYp z{l(5ym6gqu9BMcEND5-U?Y!6KuKcu9r>tUN6~2r>FYyP^ZkTjYTyhOl5nf7UAKPQA z5B1TSLR1`g*3V8q@^SLiWaen`j%)d9Z7d5`#5`3BJ0h*M=>!eTZAF8a2oRxKe@WbD ze5F$Y_X@%Rt=od#cykGF*la6{GMU6I}5St0o~xX$>*R8ho&ocD$Z4;&cbgu^Qjw%x`2uOM+AjHC~M z^7N{oa1Rs80sp<)>+;-wG$Cn&F6YUj+4KN6tDoSJhH7a+B1{WwqlWu;=fqd3x9|xE zsW|o{tC2s4{YB&0=J;U_vc@G}l#I|5sPM>@5Ky9trIA0tGE)x{lSqw}PI&aq!%`SU zZ0daMMw5%uu9T+N&-^lTVnDn=|gb~U=fAdv)e=)Ex_>6 zOi{)JSc`P680*3hDcGx47S_uR$bbQrR5+#yQwUQvH=4@w@oTWE$)!1*2cSHkvw&JM zBdWF5BQj4;KKgN21dIsu65Q zD%x65(IMcP$Q5w%JYrps&3p4wK3ZF`pd)d~?S`WbBbPqS%v;G~W0%v73QjxFSgEn( zHq9y7#^KWSD6~i~*GDju73!aVr~RI?I)eNrW&UP_vUGfPSs3AB57Bwak&)uj`9s+c zZd-D$(z3GGBYvj+;EBei;tzV7G%$j24Dy=BhlkZLrda?x@z^6WiaY2Kl23#u6iN(b z0Ens|TW)bxKDvGc(e|0MwT;p}5z0?kq~i|i$j@g)?%BeGP{%4E%HvI`jL*~MD-q%< zE9zd8AJNKC(w>J|!MgjbsfXtvGLHy$gNtj|T;-x1#QJM@I2Cu}7f3$)hmGGv`;!mP zX*g9Z13WPiwo6_9+EK}xmvpncvg5vR;c z?ru3p{$e5THjgGG3r1HYyvKi~)ou}}zBq*(G zkqJLR7%w%>`f9u;J-xmv&1KmA9QD;~6=&#%TjC!p`Jbjfz<=cjEN{2h{VC|*zDMEc zt3!!#F7WW~8MI4{ack~~;xKCknGkoQ>)UM59|@dGZE>3bw(dl}re?Xn%OQJ33AgfA zH=Q+L?LWUKLN;au8b7{7X6;Pgt-N9O9{Yc(ZZ6x}sPN4*d%I|Sv9s6x#@UkV2V`O1 zMD4Ld-^N__{(=IzMPA8J3CH1dwakwxG}0Qk2yXw#7%AvqTjT%Q|1GEh{*f`Pe{Af37}$Tc zCahqPbHC#1rUd_8k>Kwf{vibZ!^%vI7mab>>8^X>hricp68!RK-rkw!3ZI%fmlx-q z=VdO~op>loto;4*Z1+t(pcU6Q!)i$$2;IJ`_dt%Uhwcb|!tH2q^dcW~Bjun}GFTFQpnN4^`x!+fLPRlhqrOpOHu$PK&9 zQNep~tKM2l$WiG6T3EXO81|H7|37e+FChH>7ysPR5tBuV1nOgyH2QNfd9n`vNh$;TQ2(BfHMp6)k|WV+ApegekS>-};xw@2yRnhlNIGH;UA+G{kr5z&`cS58no2Y!ZKO6d4o?vl`5eQKbgyvV6^N zIn|i>R{-hH#Ex2Nr$|l24zUx^%5qz?hVL7P#f69egvv7EOf>W!z40oDT`&GqRgp&d!v0PWyjo~bCi`8&l+P$+=S!QL zQxmox04hirC%|Z#Orci?q|zK@7%GIIeh5Q;mf?4|SWc@nY(2v|do*aHMP13mSaqBs zMpvnXU>aj%DTk_o|2v=w5NoR;rd||G8s-NsCJ$r~zqZH+au#Rp*B*?ra4A!9^^Lt= zG618#^pidu1G&U15~-U)<%Ie{?||vJI-*f5Z3tFk$OE>ap+9RSF9Y(~J0O%P)6@76 z62z_IKSOAnjlbQ@hLL8Ma*Wi@6^VMUcxiBJCq*+tKU?&dEXSCvT#3AanMktaWmk<( zKK+ zT}4_(0FRL}ZC>Kk2M~{!ecx4tbN_p5&>s&cJEip^(Lz_gii-$5h?qs7)x}{rOIA!y zSr@%46hgd9W3VOZnq0^f-R%&}_qZGqN`o}8=SRE)BEyX#vNXsf{Uj5H1Qlk6lyKKc zThP7SM(KK|n6;$AR9u-r`P%G=V{)aVHLAMd-bg*d>uRh5Rf^#6P543Bn2n#@NOzOw zt#_}}lC?;iW?O@2=}@F_E#e-NdK5~l%Q6)Iek-dBiAz)%$0FN#j91}R#wM_%u5I2w z>ZeoL59tYGA|#-;_!LR6=ER8Q((?i43WPh;yilnMpqKVM(6X^x2TIQdig;Q z{jy)TH7pP+#;>sYH9&ZR%iW0r{N_m|6@>sA=Gfti{93HrxI+I9kl>^sLwh>tdKf?u z(h>Wx<*|qgCrRTqs2M5@h%@gTq6FD?IAG?l0*i>|t{(m8e4=P{zp89M+v> z#CLVqG<@C6R;Ew5s%>~3B%OAnq=>!)+tAMSoGY~%6faC9KMOB6tBYw(6JsHHh53MP zZez%%lP6l7bNC6^z{tyjvRiN%HeO#LbWg`;2w9U+W;HL*gG@b!rDwXG$Ta!{ILF(v z<6Q7#Jvd0T8o5rXD+e^2thdN44@s-FMMb zM@8-NfgxE1`{-PTF6(c1u>vB-pD)p8RKm4a$niQneILgiaN-6uv;;@+HKS@0(W+c7 z(Rs7g(Xz++2WkfdMm??QJQcZT%)h|(RNPWWo08LPE31~$>f^ma?-g?_Fvhb*S7vqS zW`7}$>;Z0ZZehu@-N7QmMw%FQ4-Z2)R8)qY(;92DtkFwIv`sWs4ISZ&+~E%$p8jz~e*+wZ5N2!Dpkd;XfFCFp z*>yCiFGOB@G`^9OIUMHwl`U^AWcrVJykv!RTfo z;L40ExDNj@h3z4m$(g$?_GcWPGnjh&nu0z|LsZVRbG-VqVR*q@2t_`2^^$TB^E)6{ zV2Cg-&N=$>@A1FJYdQInGDt*-h4QZR=xnnb{Tb(3rW>VI@Krs!m7$wQA~UrYvSx?D zEWW0Pg1SNZIJFZt$^M7CaMV#7c_c^+$HN&Y#GR!(3)}KuLu}_^>d+_>?w8o-hyFXV zaLQQEPj_NnIOMX!NHa<{y|UjZ%UjAOX%Jk%F+MmL9r}q9!<1V-U1(yU`J74xE~6Oz z=l+v07{>#*i~J+WR3`l&8MoJ9L5rI3(t@1sLXhFVMllKf6ff2srC&!4m-?wC6D_nv zTeCS7!0-XiNk&slA37n~@!LrH>r@;OcRevsyvxC7Ya6t*_`z2969%?;xH(>F2|?tH zZoESJ!o&zucVnhus=xY>4Bm+Ive>B^IHm0D(WpGiVdah=ZXvTBh>JJAJwR!B5w@NvEns2nP^cxNr&^ z;)Q>ghWHiMQP8QGiS>P2u5D6q-gPv6EFv^$_oo=4Ujr+|oOHe+pFcYof{LBTW9qs? zSR)U)C4@ckKH&QQFy>UaL8 z#}mTsyxyfR4d@n{e)@YrO3flE>wbgX3%AW}CRYOE$ch`Zk6jL9FHN%P&~25HAqR*` z@24Fr+SbNON0Q8C9CUFr${bu*|D*D4WtNp|QfUEaIl}B|ns2vc(bVZ78;`MCFYoO5 zJReQ737MsAg7{%{ynNC`@TH+^5ceq7#sf)gT3v|F38dIJs)&$!bd?ZWIZIE9N;saL z;}(NBl;G&`;T-@QISPj$nf=*A6jzYitcRy{khUZKX_)eW;S)VbVI`PuGx z`vYUqqv)s8N$J)Z>6KXt+f1ou-*56@`TE8;CQyal`a=UmfhpYIt{bLy8n)u2oXcMN z%EP{u=-HTr!}sq8N9oS=7(y`kWR{UixNB_rCq-73N;FpR5S=$qYXZqzlYo3WXmJL; zDaBx?DfayK*eRqzHR1f2uuNNXBGe+Op-yhRU;_+7HaR+Ezy)#?HfmI9|5*kRy#tP9 z=}}b^?fE;^;&%Kzr9TY|5f5}~KcCU6P@++^<(ZC{`UYC^5tfYfZG@%*x@ixCs)9R{ zVT+{(%qAX2>9J)rL)Mag(%BF{3Pm$AkN;kI#?9&i$ZPL8{WeNAOk!M$&Ny=tKodsUg zZL|Jn@{h!6vbyj|U2PUcfb#JKkBW~iESla)Jck-08y(S{!E0dy{-xoaQ3RWs#Oo3N zLI!#%!mvt8#QdaL&OtrenJ=veOMk(v;(C5}5I2D)uzKvB#0F!a)3(O=h39o-AA(NI zriUO*WKJ*Pz4Au{)gPG}6@i@qS7j6vx25lB5H0)_c~3H=*E0;S=t-2zs^=*FHO6QRQPMSF@c_<3lAbgg%OixcTDM?om>Wjj6HgWTXj#j}wo2 zhyL=&UIx(XtOUmOT!CT1M4OGiyM#YC9Gdy&^W&leF^geBs;}z25rINIKor7c{+8rs z0{w95g{I{X6^96O94~L41&bmf+uSy}a`=%uoo`<11UkEkgpu^Bmu~xWV+N8Pu1!p3 z&kqBsIA~i#4|SSMGAi7HPes+oNPfkl+H<1>rMyTaI-ggn1EtED2snN}KvEE>;A`Rf zq(skHaU&@06W*?qm%{Rjoy`1j_bC(D6CU5LvF_32!)c+1YMG>)duU}YZ`^vB{W{iT z<(gnf_Mn;t{=ND3XQzDFdSdmDDSwXH1MG?B9nuy0P1%f9P%_}QvKJW7t;1mIn>8`& zbX1{ayJYrkjqd>CVpFYw4~J%TIC$||FP$5Q7A$>9TNXn0asuywM3oH`lJht5P31!? ze8S7Wvv|&v`U!yxeL+ybz!xu+7v!5g$IdVKA_2`}R*S>Mv;ZA&xIP*egAmjd?(9@ess|8r}?b|A{^(F)e$;SBy5F03j@)NW^|@#l=9B*OIA+Nl?YX><6-aBT3lWK*KOwFQB9 z9O9*)D|_nNkwlLjmFecUznTeh=BI`4fO_Ukx{B&RoO)VSy64~=nYMga0lieOhoNd8 zSTX*j{KCLRyF>FTV)0?67wa$zBskB!LHz54hK*y8XNG+dkvg>61A6>AnttA&Hxffa zGkE=R9AcyI|SR0gEwb=~&v$KD8LW%zF^hEOGn zzPq+P`^*JcUmcCBBXnATzoD_TtPXJc{S?BM&c?b)2yf_(?_ zfuHFpAdVnLg6#uT7<`boXg3QI9o0nh!=+y+G6{h%!%zqX6}2WVxlWjk*%()pjpMi; zM(%A2O?9L~w>wxLMfrloAK8H^oa-{?Gyyh!N}e|3FlN0IE(KhaCo?hYxdn3UUe2eB&2R*9xnYgaDQDmJ2c;P zThOkBT6ZXKt7{Lov?+EggT5lMzj7?vEdWREYh*Df0SJ_AF*Po9mk~jAJ z4#<(;(FlaQHY82-#G6hs{+I^qg)U4p&crh7dPUt7DVf%B9hr=phVOv3rx$_QIiS!L z5)nb`i$i%}A18VBH(bwD%T{Q(r`;ahWpsv)P8aFhgt`6ngHVa}*e4>fiPu+)+7h9Oe@1MB>xL+i{RuUi$)V=Re?QIcO8`vAtMW6F<4 zIV@nd>}~CDl{13XF_PBE1%-s{So%2e24Wii4<>J`2iwrB?7>vtXc6`%R2+jUluVJR z`pfFYpJK(5gYK(KIv(Ex_W*ZTu?na#V@qxu}EA}|U-VPkn-OZtubGl{|Ijj&XM9zJsX)hn4QRJzDn z^s4h!qZnq)JNMG9r#LoYcutr2L!C{uMQabe$UC4jgSi<=JVfA(9wllZL{KKvkj!nx z)ZJ9W-G5oYTc$kvttr!cc}JWX_IVhcbx0KyY#k>&WGxkcfiY-#BmxnzScTOp2BYVn za-D*(Y|@rR-mYQPLE-Whto3!^WAr*LBsn#xj%qzVN{3*Q_@9+XYee2F*n0bS09TGu zZ)}Z3(Le@D3ct^bU;s!9`(ulT7{vU;x+G`Z!O!z;_2J?sVrr9sfLo+y=F_v7Wr!Q=X5PefA4&gn@tT%J`=mdkmjA|^G{ zZ7Q6f%-GGhM>m!w2Tc%V>4plQ3G}W%Rqeph=X)oFQBIS1sfqet?QEG75kR~73%PN7 z+v@)l^acz0E!*PhX$-c^#~mq-@?cporN`9+vK^n<;vfq_Bj z1-Wb8T;DLbKuabkYH%)z6JnE?o7qmJz{Do%x~W6MXWVANtweD;tnENJX-uFZ;!%TH znTAk-j9DYYpo~j;(=*nmCZ3%jtkFD`1~ZzP)Lj05&zNEbR03n8d`#?!aR#jl>!^M@ z;~zt12O5Rc{{VR;ui6ll8t@Lk%!Ap|AbyKJ^_(Y2-oTAS6FqvueMDhEh%q-`Jk*$S zFxG4F1kXgV&M^M~0l^?5r2;2%ZzO;M)FE*{pUzpI4_}Yn<42v&hS>QWq0)(+-r@c~ z?g>t3z$|y?`y(lG3n-r5&zOjrdBQ&`m|q3LNT+|CYJqT(S()O<`TS&#oXXlBS)Kj! zjcSrXO6cK`-EIROkjP3HP^=dXm2rwo&dV{mw8!+~A;onQ8<&lKGry1+F51+5`pAV+ zBW%cM2XegJB&bTVzWn{&7)oM5iD0zLZ1V(;lhsfN(j+v*8|7mTP%4DYm=Q@NG&Q|B zX2}2#9o%9zJl#a%lFU*7)1zopbG_q4lKDxITeDpw1sLxlbf6d*bTaWeFeQ^jQ6n;o z6}Ow*VoCy0Skx(o0e?7}qs=()u4DSDh>{E>#VmPsRPVfzGGcIaPY0>+Zlo;{NhG5| zL?P;A<>9zhBYBIkPg9pDD!a{v*gGO-cZV7%w>0L##d`TlDwvsLC9s{**BM`u?mmBC z<_y`5T+KerwNObA&tuEt8^}pY2^OnuamGwJ09$T3UM}V^S5~A2m{>zGja*HOD2ggE zsPp6h05K#NiCwleCCiMF0LcedQ!q(tDS_rz2w1` zHs4n2c;LvS%ppnASNJ_%oKOwpw6hRB>B7gfhyw(}D2a(Yy24UK5mHAIU2oOn!cG=c z-klQB<<|)h6v&aaje&YrBTx+45GLqCI0TxV%{f$VO-#i?=(Te}7=CP|5n@R)Z>~pK zQ2d|_4(>na-U}v>w6WnoAI?DV!qMON>RGCC)07gYufH;Rd}Fh^fXQUeJlR=fi*!!E z1BF7ML;(Q&NArnKAiqBF)AYRzzk$n@QeC?+#{z$t+Dfq}dq&$%BU8KrhE!IB)OzFn z!|zC_Wk*Cmy3X{&Yz9StEb%yqLt3UUPM&KdZkR)NN-676jKWBiiFN#ap9$BjgdCM^ zp5GTnL?=;1$dc+}DkJJvSi7vtV8L-%GnZ79cU(;M{o)}E=tL)G_0F}6Ich=jiTCcX zQ0!_{lgV(-C|GQSo+5Sg_<5eI0t~VvCCnNx+nSCC!!yg_)7Ix1p&^*Dz)4BFF-fM7 zU2YkA{k$C)76XRkcRnLGu1X>?=1tF|0Ewp2Ato6LE;)FU3&eIQpA!HV964m1DCs05 z-oAHP629!C1Wv_5`tgB4S%@=|M*KhfkZXv7LvgEze~g?uY*?miS{8g_3T9#~NS*Wb zhNmWkNR^&A-!Iia3W|yU0F#$+MPc(2^Y_f)I`CAKMw`^gd}Tq|kvR59>&7*hGKJCz zKeTU{wGf#%=TR{~TEra{yCJUPS*n@j+OA`oi!bE=?#gr`qta(c(d)pXrI3oxrmLhIB;XGhrIaa%^v%jh16%jL4T>k*& z@$-QJ5ot&=O~V}gIeOVLkeWy-LU-pX*H{HXBj!koK(7o(@znxHxLtAM(?+UnOCW}l z8S~su^^7L$jS_uDe%>lh*=R_)HT-zS9^D2A#7c4NB1NGBN=CKc;1z~hp<69LSBO4x zL53n=yHS{y^h8evCyr)p)+)Ewzwk!|q|>R>w$f@189m+1D1+sW_kHDd>`1*Y5ps%R zlPtYnNclmba|fkCO- XW@0xO9xwB(t-OQ@=H None: + """Caches all of the examples from an interface.""" + if os.path.exists(self.cached_file): + print( + f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache." + ) + else: + print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'") + cache_logger = CSVLogger() + cache_logger.setup(self.outputs, self.cached_folder) + for example_id, _ in enumerate(self.examples): + try: + prediction = self.process_example(example_id) + cache_logger.flag(prediction) + except Exception as e: + shutil.rmtree(self.cached_folder) + raise e + + def process_example(self, example_id: int) -> Tuple[List[Any], List[float]]: + """Loads an example from the interface and returns its prediction.""" + example_set = self.examples[example_id] + raw_input = [ + self.inputs[i].preprocess_example(example) + for i, example in enumerate(example_set) + ] + processed_input = [ + input_component.preprocess(raw_input[i]) + for i, input_component in enumerate(self.inputs) + ] + predictions = self.fn(*processed_input) + if len(self.outputs) == 1: + predictions = [predictions] + processed_output = [ + output_component.postprocess(predictions[i]) + if predictions[i] is not None + else None + for i, output_component in enumerate(self.outputs) + ] + + return processed_output + + def load_from_cache(self, example_id: int) -> List[Any]: + """Loads a particular cached example for the interface.""" + with open(self.cached_file) as cache: + examples = list(csv.reader(cache, quotechar="'")) + example = examples[example_id + 1] # +1 to adjust for header + output = [] + for component, cell in zip(self.outputs, example): + output.append( + component.restore_flagged( + self.cached_folder, + cell, + None, + ) + ) + return output diff --git a/gradio/flagging.py b/gradio/flagging.py index 2ec380dd22..63a55bc967 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -125,11 +125,11 @@ class CSVLogger(FlaggingCallback): if flag_index is None: csv_data = [] - for component, sample in zip(self.components, flag_data): + for idx, (component, sample) in enumerate(zip(self.components, flag_data)): csv_data.append( component.save_flagged( flagging_dir, - component.label, + component.label or f"component {idx}", sample, self.encryption_key, ) @@ -140,7 +140,10 @@ class CSVLogger(FlaggingCallback): csv_data.append(username if username is not None else "") csv_data.append(str(datetime.datetime.now())) if is_new: - headers = [component.label for component in self.components] + [ + headers = [ + component.label or f"component {idx}" + for idx, component in enumerate(self.components) + ] + [ "flag", "username", "timestamp", diff --git a/gradio/interface.py b/gradio/interface.py index cad838d82f..65dabc5512 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -25,7 +25,6 @@ from gradio.blocks import Blocks from gradio.components import ( Button, Component, - Dataset, Interpretation, IOComponent, Markdown, @@ -34,10 +33,10 @@ from gradio.components import ( get_component_instance, ) from gradio.events import Changeable, Streamable +from gradio.examples import Examples from gradio.external import load_from_pipeline # type: ignore from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore from gradio.layouts import Column, Row, TabItem, Tabs -from gradio.process_examples import cache_interface_examples, load_from_cache if TYPE_CHECKING: # Only import for type checking (is False at runtime). import transformers @@ -126,7 +125,6 @@ class Interface(Blocks): flagging_dir: str = "flagged", flagging_callback: FlaggingCallback = CSVLogger(), analytics_enabled: Optional[bool] = None, - _repeat_outputs_per_model: bool = True, **kwargs, ): """ @@ -174,14 +172,13 @@ class Interface(Blocks): inputs = [] self.interface_type = self.InterfaceTypes.OUTPUT_ONLY - if not isinstance(fn, list): - fn = [fn] - else: + if isinstance(fn, list): raise DeprecationWarning( "The `fn` parameter only accepts a single function, support for a list " "of functions has been deprecated. Please use gradio.mix.Parallel " "instead." ) + if not isinstance(inputs, list): inputs = [inputs] if not isinstance(outputs, list): @@ -199,7 +196,7 @@ class Interface(Blocks): raise ValueError( "If using 'state', there must be exactly one state input and one state output." ) - default = utils.get_default_args(fn[0])[inputs.index("state")] + default = utils.get_default_args(fn)[inputs.index("state")] state_variable = Variable(value=default) inputs[inputs.index("state")] = state_variable outputs[outputs.index("state")] = state_variable @@ -240,9 +237,6 @@ class Interface(Blocks): for o in self.output_components: o.interactive = False # Force output components to be non-interactive - if _repeat_outputs_per_model: - self.output_components *= len(fn) - if ( interpretation is None or isinstance(interpretation, list) @@ -257,10 +251,9 @@ class Interface(Blocks): raise ValueError("Invalid value for parameter: interpretation") self.api_mode = False - self.predict = fn - self.predict_durations = [[0, 0]] * len(fn) - self.function_names = [func.__name__ for func in fn] - self.__name__ = ", ".join(self.function_names) + self.fn = fn + self.fn_durations = [0, 0] + self.__name__ = fn.__name__ self.live = live self.title = title @@ -295,53 +288,7 @@ class Interface(Blocks): if not (self.theme == "default"): warnings.warn("Currently, only the 'default' theme is supported.") - if examples is None or ( - isinstance(examples, list) - and (len(examples) == 0 or isinstance(examples[0], list)) - ): - self.examples = examples - elif ( - isinstance(examples, list) and len(self.input_components) == 1 - ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists - self.examples = [[e] for e in examples] - elif isinstance(examples, str): - if not os.path.exists(examples): - raise FileNotFoundError( - "Could not find examples directory: " + examples - ) - log_file = os.path.join(examples, "log.csv") - if not os.path.exists(log_file): - if len(self.input_components) == 1: - exampleset = [ - [os.path.join(examples, item)] for item in os.listdir(examples) - ] - else: - raise FileNotFoundError( - "Could not find log file (required for multiple inputs): " - + log_file - ) - else: - with open(log_file) as logs: - exampleset = list(csv.reader(logs)) - exampleset = exampleset[1:] # remove header - for i, example in enumerate(exampleset): - for j, (component, cell) in enumerate( - zip( - self.input_components + self.output_components, - example, - ) - ): - exampleset[i][j] = component.restore_flagged( - examples, - cell, - None, - ) - self.examples = exampleset - else: - raise ValueError( - "Examples argument must either be a directory or a nested " - "list, where each sublist represents a set of inputs." - ) + self.examples = examples self.num_shap = num_shap self.examples_per_page = examples_per_page @@ -415,7 +362,7 @@ class Interface(Blocks): utils.version_check() Interface.instances.add(self) - param_names = inspect.getfullargspec(self.predict[0])[0] + param_names = inspect.getfullargspec(self.fn)[0] for component, param_name in zip(self.input_components, param_names): if component.label is None: component.label = param_name @@ -426,9 +373,6 @@ class Interface(Blocks): else: component.label = "output " + str(i) - if self.cache_examples and examples: - cache_interface_examples(self) - if self.allow_flagging != "never": if self.interface_type == self.InterfaceTypes.UNIFIED: self.flagging_callback.setup(self.input_components, self.flagging_dir) @@ -625,34 +569,16 @@ class Interface(Blocks): non_state_inputs = [ c for c in self.input_components if not isinstance(c, Variable) ] - - examples = Dataset( - components=non_state_inputs, - samples=self.examples, - type="index", - ) - - def load_example(example_id): - processed_examples = [ - component.preprocess_example(sample) - for component, sample in zip( - self.input_components, self.examples[example_id] - ) - ] - if self.cache_examples: - processed_examples += load_from_cache(self, example_id) - if len(processed_examples) == 1: - return processed_examples[0] - else: - return processed_examples - - examples.click( - load_example, - inputs=[examples], - outputs=non_state_inputs - + (self.output_components if self.cache_examples else []), - _postprocess=False, - queue=False, + non_state_outputs = [ + c for c in self.output_components if not isinstance(c, Variable) + ] + self.examples_handler = Examples( + examples=examples, + inputs=non_state_inputs, + outputs=non_state_outputs, + fn=self.fn, + cache_examples=self.cache_examples, + examples_per_page=examples_per_page, ) if self.interpretation: @@ -684,9 +610,7 @@ class Interface(Blocks): return self.__repr__() def __repr__(self): - repr = "Gradio Interface for: {}".format( - ", ".join(fn.__name__ for fn in self.predict) - ) + repr = f"Gradio Interface for: {self.__name__}" repr += "\n" + "-" * len(repr) repr += "\ninputs:" for component in self.input_components: @@ -715,31 +639,19 @@ class Interface(Blocks): input_component.serialize(processed_input[i], called_directly) for i, input_component in enumerate(self.input_components) ] - predictions = [] - output_component_counter = 0 - for predict_fn in self.predict: - prediction = predict_fn(*processed_input) + prediction = self.fn(*processed_input) - if len(self.output_components) == len(self.predict) or prediction is None: - prediction = [prediction] + if prediction is None or len(self.output_components) == 1: + prediction = [prediction] - if self.api_mode: # Serialize the input - prediction_ = copy.deepcopy(prediction) - prediction = [] + if self.api_mode: # Deerialize the input + prediction = [ + output_component.deserialize(prediction[i]) + for i, output_component in enumerate(self.output_components) + ] - # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces - for pred in prediction_: - prediction.append( - self.output_components[output_component_counter].deserialize( - pred - ) - ) - output_component_counter += 1 - - predictions.extend(prediction) - - return predictions + return prediction def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: """ @@ -777,19 +689,17 @@ class Interface(Blocks): Passes a few samples through the function to test if the inputs/outputs components are consistent with the function parameter and return values. """ - for predict_fn in self.predict: - print("Test launch: {}()...".format(predict_fn.__name__), end=" ") - raw_input = [] - for input_component in self.input_components: - if input_component.test_input is None: - print("SKIPPED") - break - else: - raw_input.append(input_component.test_input) + print("Test launch: {}()...".format(self.__name__), end=" ") + raw_input = [] + for input_component in self.input_components: + if input_component.test_input is None: + print("SKIPPED") + break else: - self.process(raw_input) - print("PASSED") - continue + raw_input.append(input_component.test_input) + else: + self.process(raw_input) + print("PASSED") def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None: """ diff --git a/gradio/mix.py b/gradio/mix.py index 9df8580fbe..63db9b3672 100644 --- a/gradio/mix.py +++ b/gradio/mix.py @@ -1,10 +1,13 @@ """ Ways to transform interfaces to produce new interfaces """ -import warnings +from typing import TYPE_CHECKING, List import gradio +if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). + from gradio.components import IOComponent + class Parallel(gradio.Interface): """ @@ -12,7 +15,7 @@ class Parallel(gradio.Interface): The Interfaces to put in Parallel must share the same input components (but can have different output components). """ - def __init__(self, *interfaces, **options): + def __init__(self, *interfaces: gradio.Interface, **options): """ Parameters: *interfaces (Interface): any number of Interface objects that are to be compared in parallel @@ -20,38 +23,29 @@ class Parallel(gradio.Interface): Returns: (Interface): an Interface object comparing the given models """ - fns = [] - outputs = [] + outputs: List[IOComponent] = [] - for io in interfaces: - if not (isinstance(io, gradio.Interface)): - warnings.warn( - "Parallel may not work properly with non-Interface objects." - ) - fns.extend(io.predict) - outputs.extend(io.output_components) + for interface in interfaces: + outputs.extend(interface.output_components) def parallel_fn(*args): return_values = [] - for fn in fns: - value = fn(*args) - if isinstance(value, tuple): - return_values.extend(value) - else: - return_values.append(value) + for interface in interfaces: + value = interface.run_prediction(args) + return_values.extend(value) + if len(outputs) == 1: + return return_values[0] return return_values + parallel_fn.__name__ = " | ".join([io.__name__ for io in interfaces]) + kwargs = { "fn": parallel_fn, "inputs": interfaces[0].input_components, "outputs": outputs, - "_repeat_outputs_per_model": False, } kwargs.update(options) super().__init__(**kwargs) - self.api_mode = interfaces[ - 0 - ].api_mode # TODO(abidlabs): make api_mode a per-function attribute class Series(gradio.Interface): @@ -60,7 +54,7 @@ class Series(gradio.Interface): and so the input and output components must agree between the interfaces). """ - def __init__(self, *interfaces, **options): + def __init__(self, *interfaces: gradio.Interface, **options): """ Parameters: *interfaces (Interface): any number of Interface objects that are to be connected in series @@ -68,41 +62,35 @@ class Series(gradio.Interface): Returns: (Interface): an Interface object connecting the given models """ - fns = [] - for io in interfaces: - if not (isinstance(io, gradio.Interface)): - warnings.warn( - "Series may not work properly with non-Interface objects." - ) - fns.append(io.predict) - def connected_fn( - *data, - ): # Run each function with the appropriate preprocessing and postprocessing - for idx, io in enumerate(interfaces): + def connected_fn(*data): + for idx, interface in enumerate(interfaces): # skip preprocessing for first interface since the Series interface will include it - if idx > 0 and not (io.api_mode): + if idx > 0 and not (interface.api_mode): data = [ input_component.preprocess(data[i]) - for i, input_component in enumerate(io.input_components) + for i, input_component in enumerate(interface.input_components) ] # run all of predictions sequentially - predictions = [] - for predict_fn in io.predict: - prediction = predict_fn(*data) - predictions.append(prediction) - data = predictions + data = interface.fn(*data) + if len(interface.output_components) == 1: + data = [data] + # skip postprocessing for final interface since the Series interface will include it - if idx < len(interfaces) - 1 and not (io.api_mode): + if idx < len(interfaces) - 1 and not (interface.api_mode): data = [ output_component.postprocess(data[i]) - for i, output_component in enumerate(io.output_components) + for i, output_component in enumerate( + interface.output_components + ) ] - return data[0] + if len(interface.output_components) == 1: + return data[0] + return data - connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns]) + connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces]) kwargs = { "fn": connected_fn, @@ -111,6 +99,4 @@ class Series(gradio.Interface): } kwargs.update(options) super().__init__(**kwargs) - self.api_mode = interfaces[ - 0 - ].api_mode # TODO(abidlabs): make api_mode a per-function attribute + self.api_mode = interfaces[0].api_mode # TODO: set api_mode per-function diff --git a/gradio/process_examples.py b/gradio/process_examples.py deleted file mode 100644 index 8e6a374378..0000000000 --- a/gradio/process_examples.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Defines helper methods useful for loading and caching Interface examples. -""" -from __future__ import annotations - -import csv -import os -import shutil -from typing import TYPE_CHECKING, Any, List, Tuple - -from gradio.flagging import CSVLogger - -if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). - from gradio import Interface - -CACHED_FOLDER = "gradio_cached_examples" -CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv") - - -def process_example( - interface: Interface, example_id: int -) -> Tuple[List[Any], List[float]]: - """Loads an example from the interface and returns its prediction.""" - example_set = interface.examples[example_id] - raw_input = [ - interface.input_components[i].preprocess_example(example) - for i, example in enumerate(example_set) - ] - prediction = interface.process(raw_input) - return prediction - - -def cache_interface_examples(interface: Interface) -> None: - """Caches all of the examples from an interface.""" - if os.path.exists(CACHE_FILE): - print( - f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache." - ) - else: - print( - f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory." - ) - cache_logger = CSVLogger() - cache_logger.setup(interface.output_components, CACHED_FOLDER) - for example_id, _ in enumerate(interface.examples): - try: - prediction = process_example(interface, example_id) - cache_logger.flag(prediction) - except Exception as e: - shutil.rmtree(CACHED_FOLDER) - raise e - - -def load_from_cache(interface: Interface, example_id: int) -> List[Any]: - """Loads a particular cached example for the interface.""" - with open(CACHE_FILE) as cache: - examples = list(csv.reader(cache, quotechar="'")) - example = examples[example_id + 1] # +1 to adjust for header - output = [] - for component, cell in zip(interface.output_components, example): - output.append( - component.restore_flagged( - CACHED_FOLDER, - cell, - interface.encryption_key if interface.encrypt else None, - ) - ) - return output diff --git a/test/test_process_examples.py b/test/test_examples.py similarity index 74% rename from test/test_process_examples.py rename to test/test_examples.py index 48ebbcd79a..653143edd2 100644 --- a/test/test_process_examples.py +++ b/test/test_examples.py @@ -1,7 +1,7 @@ import os import unittest -from gradio import Interface, process_examples +from gradio import Interface, examples os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -9,7 +9,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" class TestProcessExamples(unittest.TestCase): def test_process_example(self): io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]]) - prediction = process_examples.process_example(io, 0) + prediction = io.examples_handler.process_example(0) self.assertEquals(prediction[0], "Hello World") def test_caching(self): @@ -20,8 +20,8 @@ class TestProcessExamples(unittest.TestCase): examples=[["World"], ["Dunya"], ["Monde"]], ) io.launch(prevent_thread_lock=True) - process_examples.cache_interface_examples(io) - prediction = process_examples.load_from_cache(io, 1) + io.examples_handler.cache_interface_examples() + prediction = io.examples_handler.load_from_cache(1) io.close() self.assertEquals(prediction[0], "Hello Dunya") diff --git a/test/test_external.py b/test/test_external.py index 408e2e8772..597e9984aa 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -27,7 +27,7 @@ class TestLoadInterface(unittest.TestCase): src="models", alias=model_type, ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Audio) self.assertIsInstance(interface.output_components[0], gr.components.Audio) @@ -36,14 +36,14 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Blocks.load( name="lysandre/tiny-vit-random", src="models", alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Image) self.assertIsInstance(interface.output_components[0], gr.components.Label) def test_text_generation(self): model_type = "text_generation" interface = gr.Interface.load("models/gpt2", alias=model_type) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Textbox) @@ -52,7 +52,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/facebook/bart-large-cnn", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Textbox) @@ -61,7 +61,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/facebook/bart-large-cnn", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Textbox) @@ -70,7 +70,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/sshleifer/tiny-mbart", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Textbox) @@ -81,7 +81,7 @@ class TestLoadInterface(unittest.TestCase): api_key=None, alias=model_type, ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Label) @@ -90,7 +90,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/bert-base-uncased", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Label) @@ -99,7 +99,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/facebook/bart-large-mnli", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.input_components[1], gr.components.Textbox) self.assertIsInstance(interface.input_components[2], gr.components.Checkbox) @@ -110,7 +110,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Audio) self.assertIsInstance(interface.output_components[0], gr.components.Textbox) @@ -119,7 +119,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/google/vit-base-patch16-224", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Image) self.assertIsInstance(interface.output_components[0], gr.components.Label) @@ -130,7 +130,7 @@ class TestLoadInterface(unittest.TestCase): api_key=None, alias=model_type, ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Dataframe) @@ -141,7 +141,7 @@ class TestLoadInterface(unittest.TestCase): api_key=None, alias=model_type, ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Audio) @@ -152,7 +152,7 @@ class TestLoadInterface(unittest.TestCase): api_key=None, alias=model_type, ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Audio) @@ -161,7 +161,7 @@ class TestLoadInterface(unittest.TestCase): interface = gr.Interface.load( "models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type ) - self.assertEqual(interface.predict[0].__name__, model_type) + self.assertEqual(interface.__name__, model_type) self.assertIsInstance(interface.input_components[0], gr.components.Textbox) self.assertIsInstance(interface.output_components[0], gr.components.Image) diff --git a/test/test_interfaces.py b/test/test_interfaces.py index 1871c90c73..8d5d4c3f80 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -120,8 +120,8 @@ class TestInterface(unittest.TestCase): examples = ["test1", "test2"] interface = Interface(lambda x: x, "textbox", "label", examples=examples) interface.launch(prevent_thread_lock=True) - self.assertEqual(len(interface.examples), 2) - self.assertEqual(len(interface.examples[0]), 1) + self.assertEqual(len(interface.examples_handler.examples), 2) + self.assertEqual(len(interface.examples_handler.examples[0]), 1) interface.close() @mock.patch("IPython.display.display")