Fix bokeh Plots (#3212)

* Add implementation

* Simpler approach

* Fix extra updates

* Fix python tests

* CHANGELOG + add bokeh plot demo

* Center content

* Fix value=bokeh case

* Add image to changelog

* Add notebook file

* Undo accidental changes

* Add missing plot type

* Fix type hints

* Fix requirements

* Lint

* Fix requirements

* remove lorenz

* Fix notebooks

* Remove bokeh demo

* Fix plot

* Don't use beforeUpdate

* FIx demo: Add load_event + bump bokeh>3.0

* Only load bokeh if needed

* Fix tests

* lint

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-02-17 16:47:06 -05:00 committed by GitHub
parent a2b80ca5c6
commit 3530a86433
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 181 additions and 33 deletions

View File

@ -51,6 +51,14 @@ demo.launch()
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3157](https://github.com/gradio-app/gradio/pull/3157)
### Bokeh plots are back! 🌠
Fixed a bug that prevented bokeh plots from being displayed on the front end and extended support for both 2.x and 3.x versions of bokeh!
![image](https://user-images.githubusercontent.com/41651716/219468324-0d82e07f-8fb4-4ff9-b40c-8250b29e45f7.png)
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3212](https://github.com/gradio-app/gradio/pull/3212)
## Bug Fixes:
- Adds ability to add a single message from the bot or user side. Ex: specify `None` as the second value in the tuple, to add a single message in the chatbot from the "bot" side.

View File

@ -0,0 +1,2 @@
bokeh>=3.0
xyzservices

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: bokeh_plot"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio bokeh>=3.0 xyzservices"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import xyzservices.providers as xyz\n", "from bokeh.plotting import figure\n", "from bokeh.tile_providers import get_provider\n", "from bokeh.models import ColumnDataSource, Whisker\n", "from bokeh.plotting import figure\n", "from bokeh.sampledata.autompg2 import autompg2 as df\n", "from bokeh.sampledata.penguins import data\n", "from bokeh.transform import factor_cmap, jitter, factor_mark\n", "\n", "\n", "def get_plot(plot_type):\n", " if plot_type == \"map\":\n", " tile_provider = get_provider(xyz.OpenStreetMap.Mapnik)\n", " plot = figure(\n", " x_range=(-2000000, 6000000),\n", " y_range=(-1000000, 7000000),\n", " x_axis_type=\"mercator\",\n", " y_axis_type=\"mercator\",\n", " )\n", " plot.add_tile(tile_provider)\n", " return plot\n", " elif plot_type == \"whisker\":\n", " classes = list(sorted(df[\"class\"].unique()))\n", "\n", " p = figure(\n", " height=400,\n", " x_range=classes,\n", " background_fill_color=\"#efefef\",\n", " title=\"Car class vs HWY mpg with quintile ranges\",\n", " )\n", " p.xgrid.grid_line_color = None\n", "\n", " g = df.groupby(\"class\")\n", " upper = g.hwy.quantile(0.80)\n", " lower = g.hwy.quantile(0.20)\n", " source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))\n", "\n", " error = Whisker(\n", " base=\"base\",\n", " upper=\"upper\",\n", " lower=\"lower\",\n", " source=source,\n", " level=\"annotation\",\n", " line_width=2,\n", " )\n", " error.upper_head.size = 20\n", " error.lower_head.size = 20\n", " p.add_layout(error)\n", "\n", " p.circle(\n", " jitter(\"class\", 0.3, range=p.x_range),\n", " \"hwy\",\n", " source=df,\n", " alpha=0.5,\n", " size=13,\n", " line_color=\"white\",\n", " color=factor_cmap(\"class\", \"Light6\", classes),\n", " )\n", " return p\n", " elif plot_type == \"scatter\":\n", "\n", " SPECIES = sorted(data.species.unique())\n", " MARKERS = [\"hex\", \"circle_x\", \"triangle\"]\n", "\n", " p = figure(title=\"Penguin size\", background_fill_color=\"#fafafa\")\n", " p.xaxis.axis_label = \"Flipper Length (mm)\"\n", " p.yaxis.axis_label = \"Body Mass (g)\"\n", "\n", " p.scatter(\n", " \"flipper_length_mm\",\n", " \"body_mass_g\",\n", " source=data,\n", " legend_group=\"species\",\n", " fill_alpha=0.4,\n", " size=12,\n", " marker=factor_mark(\"species\", MARKERS, SPECIES),\n", " color=factor_cmap(\"species\", \"Category10_3\", SPECIES),\n", " )\n", "\n", " p.legend.location = \"top_left\"\n", " p.legend.title = \"Species\"\n", " return p\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " plot_type = gr.Radio(value=\"scatter\", choices=[\"scatter\", \"whisker\", \"map\"])\n", " plot = gr.Plot()\n", " plot_type.change(get_plot, inputs=[plot_type], outputs=[plot])\n", " demo.load(get_plot, inputs=[plot_type], outputs=[plot])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

94
demo/bokeh_plot/run.py Normal file
View File

@ -0,0 +1,94 @@
import gradio as gr
import xyzservices.providers as xyz
from bokeh.plotting import figure
from bokeh.tile_providers import get_provider
from bokeh.models import ColumnDataSource, Whisker
from bokeh.plotting import figure
from bokeh.sampledata.autompg2 import autompg2 as df
from bokeh.sampledata.penguins import data
from bokeh.transform import factor_cmap, jitter, factor_mark
def get_plot(plot_type):
if plot_type == "map":
tile_provider = get_provider(xyz.OpenStreetMap.Mapnik)
plot = figure(
x_range=(-2000000, 6000000),
y_range=(-1000000, 7000000),
x_axis_type="mercator",
y_axis_type="mercator",
)
plot.add_tile(tile_provider)
return plot
elif plot_type == "whisker":
classes = list(sorted(df["class"].unique()))
p = figure(
height=400,
x_range=classes,
background_fill_color="#efefef",
title="Car class vs HWY mpg with quintile ranges",
)
p.xgrid.grid_line_color = None
g = df.groupby("class")
upper = g.hwy.quantile(0.80)
lower = g.hwy.quantile(0.20)
source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))
error = Whisker(
base="base",
upper="upper",
lower="lower",
source=source,
level="annotation",
line_width=2,
)
error.upper_head.size = 20
error.lower_head.size = 20
p.add_layout(error)
p.circle(
jitter("class", 0.3, range=p.x_range),
"hwy",
source=df,
alpha=0.5,
size=13,
line_color="white",
color=factor_cmap("class", "Light6", classes),
)
return p
elif plot_type == "scatter":
SPECIES = sorted(data.species.unique())
MARKERS = ["hex", "circle_x", "triangle"]
p = figure(title="Penguin size", background_fill_color="#fafafa")
p.xaxis.axis_label = "Flipper Length (mm)"
p.yaxis.axis_label = "Body Mass (g)"
p.scatter(
"flipper_length_mm",
"body_mass_g",
source=data,
legend_group="species",
fill_alpha=0.4,
size=12,
marker=factor_mark("species", MARKERS, SPECIES),
color=factor_cmap("species", "Category10_3", SPECIES),
)
p.legend.location = "top_left"
p.legend.title = "Species"
return p
with gr.Blocks() as demo:
with gr.Row():
plot_type = gr.Radio(value="scatter", choices=["scatter", "whisker", "map"])
plot = gr.Plot()
plot_type.change(get_plot, inputs=[plot_type], outputs=[plot])
demo.load(get_plot, inputs=[plot_type], outputs=[plot])
if __name__ == "__main__":
demo.launch()

View File

@ -4132,7 +4132,17 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
)
def get_config(self):
return {"value": self.value, **IOComponent.get_config(self)}
try:
import bokeh # type: ignore
bokeh_version = bokeh.__version__
except ImportError:
bokeh_version = None
return {
"value": self.value,
"bokeh_version": bokeh_version,
**IOComponent.get_config(self),
}
@staticmethod
def update(
@ -4162,9 +4172,11 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
if isinstance(y, (ModuleType, matplotlib.figure.Figure)):
dtype = "matplotlib"
out_y = processing_utils.encode_plot_to_base64(y)
elif isinstance(y, dict):
elif "bokeh" in y.__module__:
dtype = "bokeh"
out_y = json.dumps(y)
from bokeh.embed import json_item # type: ignore
out_y = json.dumps(json_item(y))
else:
is_altair = "altair" in y.__module__
if is_altair:

View File

@ -2017,7 +2017,9 @@ simple = pd.DataFrame(
class TestScatterPlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
assert gr.ScatterPlot().get_config() == {
"caption": None,
"elem_id": None,
@ -2029,6 +2031,7 @@ class TestScatterPlot:
"style": {},
"value": None,
"visible": True,
"bokeh_version": "3.0.3",
}
def test_no_color(self):
@ -2199,6 +2202,7 @@ class TestScatterPlot:
class TestLinePlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
assert gr.LinePlot().get_config() == {
"caption": None,
@ -2211,6 +2215,7 @@ class TestLinePlot:
"style": {},
"value": None,
"visible": True,
"bokeh_version": "3.0.3",
}
def test_no_color(self):
@ -2360,6 +2365,7 @@ class TestLinePlot:
class TestBarPlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
assert gr.BarPlot().get_config() == {
"caption": None,
@ -2372,6 +2378,7 @@ class TestBarPlot:
"style": {},
"value": None,
"visible": True,
"bokeh_version": "3.0.3",
}
def test_no_color(self):

View File

@ -20,6 +20,7 @@
export let style: Styles = {};
export let theme: string;
export let caption: string;
export let bokeh_version: string | null;
</script>
<Block
@ -32,5 +33,5 @@
<StatusTracker {...loading_status} />
<Plot {value} {target} {theme} {caption} on:change />
<Plot {value} {target} {theme} {caption} {bokeh_version} on:change />
</Block>

View File

@ -5,7 +5,7 @@
import { colors as color_palette, ordered_colors } from "@gradio/theme";
import { get_next_color } from "@gradio/utils";
import { Vega } from "svelte-vega";
import { afterUpdate, onDestroy } from "svelte";
import { afterUpdate, beforeUpdate, onDestroy } from "svelte";
import { create_config } from "./utils";
import { Empty } from "@gradio/atoms";
@ -15,6 +15,7 @@
export let colors: Array<string> = [];
export let theme: string;
export let caption: string;
export let bokeh_version: string | null;
function get_color(index: number) {
let current_color = colors[index % colors.length];
@ -32,8 +33,34 @@
$: darkmode = theme == "dark";
$: if (value && value.type == "altair") {
spec = JSON.parse(value["plot"]);
$: plot = value?.plot;
$: type = value?.type;
// Need to keep track of this because
// otherwise embed_bokeh may try to embed before
// bokeh is loaded
$: bokeh_loaded = window.Bokeh === undefined
function embed_bokeh(plot, type, bokeh_loaded){
if (document){
if (document.getElementById("bokehDiv")) {
document.getElementById("bokehDiv").innerHTML = "";
}
}
if (type == "bokeh" && window.Bokeh) {
if (!bokeh_loaded) {
load_bokeh();
bokeh_loaded = true;
}
let plotObj = JSON.parse(plot);
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
}
}
$: embed_bokeh(plot, type, bokeh_loaded);
$: if (type == "altair") {
spec = JSON.parse(plot);
const config = create_config(darkmode);
spec.config = config;
switch (value.chart || "") {
@ -75,13 +102,13 @@
let plotDiv;
let plotlyGlobalStyle;
const main_src = "https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js";
const main_src = `https://cdn.bokeh.org/bokeh/release/bokeh-${bokeh_version}.min.js`;
const plugins_src = [
"https://cdn.pydata.org/bokeh/release/bokeh-widgets-2.4.2.min.js",
"https://cdn.pydata.org/bokeh/release/bokeh-tables-2.4.2.min.js",
"https://cdn.pydata.org/bokeh/release/bokeh-gl-2.4.2.min.js",
"https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js"
`https://cdn.pydata.org/bokeh/release/bokeh-widgets-${bokeh_version}.min.js`,
`https://cdn.pydata.org/bokeh/release/bokeh-tables-${bokeh_version}.min.js`,
`https://cdn.pydata.org/bokeh/release/bokeh-gl-${bokeh_version}.min.js`,
`https://cdn.pydata.org/bokeh/release/bokeh-api-${bokeh_version}.min.js`
];
function load_plugins() {
@ -100,7 +127,7 @@
script.onload = handleBokehLoaded;
script.src = main_src;
document.head.appendChild(script);
bokeh_loaded = true;
return script;
}
@ -115,10 +142,9 @@
}
}
const main_script = load_bokeh();
const main_script = bokeh_version ? load_bokeh() : null
let plugin_scripts = [];
// Bokeh
const resolves = [];
const bokehPromises = Array(5)
@ -126,7 +152,7 @@
.map((_, i) => createPromise(i));
const initializeBokeh = (index) => {
if (value && value["type"] == "bokeh") {
if (type == "bokeh") {
resolves[index]();
}
};
@ -142,23 +168,14 @@
plugin_scripts = load_plugins();
}
Promise.all(bokehPromises).then(() => {
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
});
afterUpdate(() => {
if (value && value["type"] == "plotly") {
if (type == "plotly") {
load_plotly_css();
let plotObj = JSON.parse(value["plot"]);
let plotObj = JSON.parse(plot);
plotObj.layout.title
? (plotObj.layout.margin = { autoexpand: true })
: (plotObj.layout.margin = { l: 0, r: 0, b: 0, t: 0 });
Plotly.react(plotDiv, plotObj);
} else if (value && value["type"] == "bokeh") {
document.getElementById("bokehDiv").innerHTML = "";
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
}
});
@ -170,11 +187,11 @@
});
</script>
{#if value && value["type"] == "plotly"}
{#if value && type == "plotly"}
<div bind:this={plotDiv} />
{:else if value && value["type"] == "bokeh"}
<div id="bokehDiv" />
{:else if value && value["type"] == "altair"}
{:else if type == "bokeh"}
<div id="bokehDiv" class="gradio-bokeh"/>
{:else if type == "altair"}
<div class="altair layout">
<Vega {spec} />
{#if caption}
@ -183,16 +200,22 @@
</div>
{/if}
</div>
{:else if value && value["type"] == "matplotlib"}
{:else if type == "matplotlib"}
<div class="matplotlib layout">
<!-- svelte-ignore a11y-missing-attribute -->
<img src={value["plot"]} />
<img src={plot} />
</div>
{:else}
<Empty size="large" unpadded_box={true}><PlotIcon /></Empty>
{/if}
<style>
.gradio-bokeh {
display: flex;
justify-content: center;
}
.layout {
display: flex;
flex-direction: column;