diff --git a/demo/xray_blocks/run.py b/demo/xray_blocks/run.py index 915a91a0f2..79533c2389 100644 --- a/demo/xray_blocks/run.py +++ b/demo/xray_blocks/run.py @@ -1,35 +1,38 @@ import gradio as gr import random -xray_model = lambda diseases, img : {disease: random.random() for disease in diseases} -ct_model = lambda diseases, img : {disease: random.random() for disease in diseases} + +xray_model = lambda diseases, img: {disease: random.random() for disease in diseases} +ct_model = lambda diseases, img: {disease: random.random() for disease in diseases} xray_blocks = gr.Blocks() with xray_blocks: - gr.Markdown(""" + gr.Markdown(""" # Detect Disease From Scan With this model you can lorem ipsum - ipsum 1 - ipsum 2 """) - disease = gr.inputs.CheckboxGroup(["Covid", "Malaria", "Lung Cancer"], - label="Disease to Scan For") + disease = gr.inputs.CheckboxGroup(["Covid", "Malaria", "Lung Cancer"], + label="Disease to Scan For") - with gr.Tab("X-ray"): - with gr.Row(): - xray_scan = gr.inputs.Image() - xray_results = gr.outputs.JSON() - xray_run = gr.Button("Run") - xray_run.click(xray_model, inputs=[disease, xray_scan] , outputs=xray_results) + with gr.Tab("X-ray"): + with gr.Row(): + xray_scan = gr.inputs.Image() + xray_results = gr.outputs.JSON() + xray_run = gr.Button("Run") + xray_run.click(xray_model, inputs=[disease, xray_scan], outputs=xray_results) - with gr.Tab("CT Scan"): - with gr.Row(): - ct_scan = gr.inputs.Image() - ct_results = gr.outputs.JSON() - ct_run = gr.Button("Run") - ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results) - - overall_probability = gr.outputs.Textbox() + with gr.Tab("CT Scan"): + with gr.Row(): + ct_scan = gr.inputs.Image() + ct_results = gr.outputs.JSON() + ct_run = gr.Button("Run") + ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results) -xray_blocks.launch() \ No newline at end of file + overall_probability = gr.outputs.Textbox() + +#TODO: remove later +print(xray_blocks.get_config_file()) +xray_blocks.launch() diff --git a/gradio/__init__.py b/gradio/__init__.py index efadad3a8f..8a9b9b4ab7 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -1,16 +1,16 @@ import pkg_resources +from gradio.blocks import Blocks, Column, Row, Tab from gradio.flagging import ( CSVLogger, FlaggingCallback, HuggingFaceDatasetSaver, SimpleCSVLogger, ) -from gradio.blocks import Blocks, Tab, Row, Column from gradio.interface import Interface, close_all, reset_all from gradio.mix import Parallel, Series from gradio.routes import get_state, set_state -from gradio.static import Markdown, Button +from gradio.static import Button, Markdown current_pkg_version = pkg_resources.require("gradio")[0].version __version__ = current_pkg_version diff --git a/gradio/blocks.py b/gradio/blocks.py index 9fedd60463..0c5be57e96 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1,6 +1,7 @@ -from gradio.launchable import Launchable +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple + from gradio import context, utils -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Dict +from gradio.launchable import Launchable class Block: @@ -75,8 +76,8 @@ class Blocks(Launchable, BlockContext): self.blocks = {} self.fns = [] self.dependencies = [] - - def process_api(self, data: Dict[str, Any], username: str=None) -> Dict[str, Any]: + + def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: raw_input = data["data"] fn_index = data["fn_index"] fn = self.fns[fn_index] @@ -88,7 +89,7 @@ class Blocks(Launchable, BlockContext): ] predictions = fn(*processed_input) if len(dependency["outputs"]) == 1: - predictions = (predictions, ) + predictions = (predictions,) processed_output = [ self.blocks[output_id].postprocess(predictions[i]) if predictions[i] is not None @@ -96,14 +97,16 @@ class Blocks(Launchable, BlockContext): for i, output_id in enumerate(dependency["outputs"]) ] return {"data": processed_output} - + def get_template_context(self): return {"type": "column"} def get_config_file(self): + from gradio.component import Component + config = {"mode": "blocks", "components": [], "theme": self.theme} for _id, block in self.blocks.items(): - if not isinstance(block, BlockContext): + if isinstance(block, Component): config["components"].append( { "id": _id, @@ -111,6 +114,7 @@ class Blocks(Launchable, BlockContext): "props": block.get_template_context(), } ) + def getLayout(block_context): if not isinstance(block_context, BlockContext): return block_context._id @@ -128,10 +132,8 @@ class Blocks(Launchable, BlockContext): if len(running_tabs): children.append({"type": "tabset", "children": running_tabs}) running_tabs = [] - return { - "children": children, - **block_context.get_template_context() - } + return {"children": children, **block_context.get_template_context()} + config["layout"] = getLayout(self) config["dependencies"] = self.dependencies return config diff --git a/gradio/component.py b/gradio/component.py index 0671787d3b..c9fba82736 100644 --- a/gradio/component.py +++ b/gradio/component.py @@ -5,6 +5,7 @@ from typing import Any, Dict from gradio import processing_utils from gradio.blocks import Block + class Component(Block): """ A class for defining the methods that all gradio input and output components should have. diff --git a/gradio/context.py b/gradio/context.py index 15949b684c..fa15b9cb2e 100644 --- a/gradio/context.py +++ b/gradio/context.py @@ -1,3 +1,3 @@ root_block = None block = None -id = 0 \ No newline at end of file +id = 0 diff --git a/gradio/interface.py b/gradio/interface.py index 806c026201..3439dc5e7d 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -12,23 +12,23 @@ import re import time import warnings import weakref -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Dict -from gradio.routes import predict +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from markdown_it import MarkdownIt from mdit_py_plugins.footnote import footnote_plugin from gradio import interpretation, utils -from gradio.launchable import Launchable from gradio.external import load_from_pipeline, load_interface # type: ignore from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore from gradio.inputs import InputComponent from gradio.inputs import State as i_State # type: ignore from gradio.inputs import get_input_instance +from gradio.launchable import Launchable from gradio.outputs import OutputComponent from gradio.outputs import State as o_State # type: ignore from gradio.outputs import get_output_instance from gradio.process_examples import load_from_cache, process_example +from gradio.routes import predict if TYPE_CHECKING: # Only import for type checking (is False at runtime). import flask @@ -538,7 +538,7 @@ class Interface(Launchable): else: return predictions - def process_api(self, data: Dict[str, Any], username: str=None) -> Dict[str, Any]: + def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: flag_index = None if data.get("example_id") is not None: example_id = data["example_id"] diff --git a/gradio/launchable.py b/gradio/launchable.py index a838976867..27fdc452b1 100644 --- a/gradio/launchable.py +++ b/gradio/launchable.py @@ -7,18 +7,18 @@ import time import webbrowser from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple - -from gradio import networking, queueing # type: ignore -from gradio import encryptor, strings, utils +from gradio import encryptor, networking, queueing, strings, utils # type: ignore from gradio.process_examples import cache_interface_examples if TYPE_CHECKING: # Only import for type checking (is False at runtime). import flask + class Launchable: """ Gradio launchables can be launched to serve content to a port. """ + def launch( self, inline: bool = None, @@ -231,4 +231,3 @@ class Launchable: self.server.close() if self.enable_queue: queueing.close() - diff --git a/gradio/routes.py b/gradio/routes.py index 24fda4ac86..72a95fa392 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -195,9 +195,7 @@ def api_docs(request: Request): async def predict(request: Request, username: str = Depends(get_current_user)): body = await request.json() try: - output = await run_in_threadpool( - app.launchable.process_api, body, username - ) + output = await run_in_threadpool(app.launchable.process_api, body, username) except BaseException as error: if app.launchable.show_error: traceback.print_exc() diff --git a/gradio/static.py b/gradio/static.py index 7074db5956..ffe9feec58 100644 --- a/gradio/static.py +++ b/gradio/static.py @@ -2,6 +2,7 @@ from __future__ import annotations from gradio.component import Component + class StaticComponent(Component): def __init__(self, label: str): self.component_type = "static" @@ -17,5 +18,6 @@ class StaticComponent(Component): class Markdown(StaticComponent): pass + class Button(StaticComponent): - pass \ No newline at end of file + pass diff --git a/gradio/templates/frontend/index.html b/gradio/templates/frontend/index.html index 6b89a16e65..ae177678ca 100644 --- a/gradio/templates/frontend/index.html +++ b/gradio/templates/frontend/index.html @@ -45,10 +45,10 @@