mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Fix bokeh Plots (#3212)
* Add implementation * Simpler approach * Fix extra updates * Fix python tests * CHANGELOG + add bokeh plot demo * Center content * Fix value=bokeh case * Add image to changelog * Add notebook file * Undo accidental changes * Add missing plot type * Fix type hints * Fix requirements * Lint * Fix requirements * remove lorenz * Fix notebooks * Remove bokeh demo * Fix plot * Don't use beforeUpdate * FIx demo: Add load_event + bump bokeh>3.0 * Only load bokeh if needed * Fix tests * lint --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
a2b80ca5c6
commit
3530a86433
@ -51,6 +51,14 @@ demo.launch()
|
||||
|
||||
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3157](https://github.com/gradio-app/gradio/pull/3157)
|
||||
|
||||
### Bokeh plots are back! 🌠
|
||||
|
||||
Fixed a bug that prevented bokeh plots from being displayed on the front end and extended support for both 2.x and 3.x versions of bokeh!
|
||||
|
||||
![image](https://user-images.githubusercontent.com/41651716/219468324-0d82e07f-8fb4-4ff9-b40c-8250b29e45f7.png)
|
||||
|
||||
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3212](https://github.com/gradio-app/gradio/pull/3212)
|
||||
|
||||
|
||||
## Bug Fixes:
|
||||
- Adds ability to add a single message from the bot or user side. Ex: specify `None` as the second value in the tuple, to add a single message in the chatbot from the "bot" side.
|
||||
|
2
demo/bokeh_plot/requirements.txt
Normal file
2
demo/bokeh_plot/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
||||
bokeh>=3.0
|
||||
xyzservices
|
1
demo/bokeh_plot/run.ipynb
Normal file
1
demo/bokeh_plot/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: bokeh_plot"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio bokeh>=3.0 xyzservices"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import xyzservices.providers as xyz\n", "from bokeh.plotting import figure\n", "from bokeh.tile_providers import get_provider\n", "from bokeh.models import ColumnDataSource, Whisker\n", "from bokeh.plotting import figure\n", "from bokeh.sampledata.autompg2 import autompg2 as df\n", "from bokeh.sampledata.penguins import data\n", "from bokeh.transform import factor_cmap, jitter, factor_mark\n", "\n", "\n", "def get_plot(plot_type):\n", " if plot_type == \"map\":\n", " tile_provider = get_provider(xyz.OpenStreetMap.Mapnik)\n", " plot = figure(\n", " x_range=(-2000000, 6000000),\n", " y_range=(-1000000, 7000000),\n", " x_axis_type=\"mercator\",\n", " y_axis_type=\"mercator\",\n", " )\n", " plot.add_tile(tile_provider)\n", " return plot\n", " elif plot_type == \"whisker\":\n", " classes = list(sorted(df[\"class\"].unique()))\n", "\n", " p = figure(\n", " height=400,\n", " x_range=classes,\n", " background_fill_color=\"#efefef\",\n", " title=\"Car class vs HWY mpg with quintile ranges\",\n", " )\n", " p.xgrid.grid_line_color = None\n", "\n", " g = df.groupby(\"class\")\n", " upper = g.hwy.quantile(0.80)\n", " lower = g.hwy.quantile(0.20)\n", " source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))\n", "\n", " error = Whisker(\n", " base=\"base\",\n", " upper=\"upper\",\n", " lower=\"lower\",\n", " source=source,\n", " level=\"annotation\",\n", " line_width=2,\n", " )\n", " error.upper_head.size = 20\n", " error.lower_head.size = 20\n", " p.add_layout(error)\n", "\n", " p.circle(\n", " jitter(\"class\", 0.3, range=p.x_range),\n", " \"hwy\",\n", " source=df,\n", " alpha=0.5,\n", " size=13,\n", " line_color=\"white\",\n", " color=factor_cmap(\"class\", \"Light6\", classes),\n", " )\n", " return p\n", " elif plot_type == \"scatter\":\n", "\n", " SPECIES = sorted(data.species.unique())\n", " MARKERS = [\"hex\", \"circle_x\", \"triangle\"]\n", "\n", " p = figure(title=\"Penguin size\", background_fill_color=\"#fafafa\")\n", " p.xaxis.axis_label = \"Flipper Length (mm)\"\n", " p.yaxis.axis_label = \"Body Mass (g)\"\n", "\n", " p.scatter(\n", " \"flipper_length_mm\",\n", " \"body_mass_g\",\n", " source=data,\n", " legend_group=\"species\",\n", " fill_alpha=0.4,\n", " size=12,\n", " marker=factor_mark(\"species\", MARKERS, SPECIES),\n", " color=factor_cmap(\"species\", \"Category10_3\", SPECIES),\n", " )\n", "\n", " p.legend.location = \"top_left\"\n", " p.legend.title = \"Species\"\n", " return p\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " plot_type = gr.Radio(value=\"scatter\", choices=[\"scatter\", \"whisker\", \"map\"])\n", " plot = gr.Plot()\n", " plot_type.change(get_plot, inputs=[plot_type], outputs=[plot])\n", " demo.load(get_plot, inputs=[plot_type], outputs=[plot])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
94
demo/bokeh_plot/run.py
Normal file
94
demo/bokeh_plot/run.py
Normal file
@ -0,0 +1,94 @@
|
||||
import gradio as gr
|
||||
import xyzservices.providers as xyz
|
||||
from bokeh.plotting import figure
|
||||
from bokeh.tile_providers import get_provider
|
||||
from bokeh.models import ColumnDataSource, Whisker
|
||||
from bokeh.plotting import figure
|
||||
from bokeh.sampledata.autompg2 import autompg2 as df
|
||||
from bokeh.sampledata.penguins import data
|
||||
from bokeh.transform import factor_cmap, jitter, factor_mark
|
||||
|
||||
|
||||
def get_plot(plot_type):
|
||||
if plot_type == "map":
|
||||
tile_provider = get_provider(xyz.OpenStreetMap.Mapnik)
|
||||
plot = figure(
|
||||
x_range=(-2000000, 6000000),
|
||||
y_range=(-1000000, 7000000),
|
||||
x_axis_type="mercator",
|
||||
y_axis_type="mercator",
|
||||
)
|
||||
plot.add_tile(tile_provider)
|
||||
return plot
|
||||
elif plot_type == "whisker":
|
||||
classes = list(sorted(df["class"].unique()))
|
||||
|
||||
p = figure(
|
||||
height=400,
|
||||
x_range=classes,
|
||||
background_fill_color="#efefef",
|
||||
title="Car class vs HWY mpg with quintile ranges",
|
||||
)
|
||||
p.xgrid.grid_line_color = None
|
||||
|
||||
g = df.groupby("class")
|
||||
upper = g.hwy.quantile(0.80)
|
||||
lower = g.hwy.quantile(0.20)
|
||||
source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))
|
||||
|
||||
error = Whisker(
|
||||
base="base",
|
||||
upper="upper",
|
||||
lower="lower",
|
||||
source=source,
|
||||
level="annotation",
|
||||
line_width=2,
|
||||
)
|
||||
error.upper_head.size = 20
|
||||
error.lower_head.size = 20
|
||||
p.add_layout(error)
|
||||
|
||||
p.circle(
|
||||
jitter("class", 0.3, range=p.x_range),
|
||||
"hwy",
|
||||
source=df,
|
||||
alpha=0.5,
|
||||
size=13,
|
||||
line_color="white",
|
||||
color=factor_cmap("class", "Light6", classes),
|
||||
)
|
||||
return p
|
||||
elif plot_type == "scatter":
|
||||
|
||||
SPECIES = sorted(data.species.unique())
|
||||
MARKERS = ["hex", "circle_x", "triangle"]
|
||||
|
||||
p = figure(title="Penguin size", background_fill_color="#fafafa")
|
||||
p.xaxis.axis_label = "Flipper Length (mm)"
|
||||
p.yaxis.axis_label = "Body Mass (g)"
|
||||
|
||||
p.scatter(
|
||||
"flipper_length_mm",
|
||||
"body_mass_g",
|
||||
source=data,
|
||||
legend_group="species",
|
||||
fill_alpha=0.4,
|
||||
size=12,
|
||||
marker=factor_mark("species", MARKERS, SPECIES),
|
||||
color=factor_cmap("species", "Category10_3", SPECIES),
|
||||
)
|
||||
|
||||
p.legend.location = "top_left"
|
||||
p.legend.title = "Species"
|
||||
return p
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
plot_type = gr.Radio(value="scatter", choices=["scatter", "whisker", "map"])
|
||||
plot = gr.Plot()
|
||||
plot_type.change(get_plot, inputs=[plot_type], outputs=[plot])
|
||||
demo.load(get_plot, inputs=[plot_type], outputs=[plot])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -4132,7 +4132,17 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return {"value": self.value, **IOComponent.get_config(self)}
|
||||
try:
|
||||
import bokeh # type: ignore
|
||||
|
||||
bokeh_version = bokeh.__version__
|
||||
except ImportError:
|
||||
bokeh_version = None
|
||||
return {
|
||||
"value": self.value,
|
||||
"bokeh_version": bokeh_version,
|
||||
**IOComponent.get_config(self),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
@ -4162,9 +4172,11 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
|
||||
if isinstance(y, (ModuleType, matplotlib.figure.Figure)):
|
||||
dtype = "matplotlib"
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
elif isinstance(y, dict):
|
||||
elif "bokeh" in y.__module__:
|
||||
dtype = "bokeh"
|
||||
out_y = json.dumps(y)
|
||||
from bokeh.embed import json_item # type: ignore
|
||||
|
||||
out_y = json.dumps(json_item(y))
|
||||
else:
|
||||
is_altair = "altair" in y.__module__
|
||||
if is_altair:
|
||||
|
@ -2017,7 +2017,9 @@ simple = pd.DataFrame(
|
||||
|
||||
|
||||
class TestScatterPlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
|
||||
assert gr.ScatterPlot().get_config() == {
|
||||
"caption": None,
|
||||
"elem_id": None,
|
||||
@ -2029,6 +2031,7 @@ class TestScatterPlot:
|
||||
"style": {},
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"bokeh_version": "3.0.3",
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
@ -2199,6 +2202,7 @@ class TestScatterPlot:
|
||||
|
||||
|
||||
class TestLinePlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
assert gr.LinePlot().get_config() == {
|
||||
"caption": None,
|
||||
@ -2211,6 +2215,7 @@ class TestLinePlot:
|
||||
"style": {},
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"bokeh_version": "3.0.3",
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
@ -2360,6 +2365,7 @@ class TestLinePlot:
|
||||
|
||||
|
||||
class TestBarPlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
assert gr.BarPlot().get_config() == {
|
||||
"caption": None,
|
||||
@ -2372,6 +2378,7 @@ class TestBarPlot:
|
||||
"style": {},
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"bokeh_version": "3.0.3",
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
|
@ -20,6 +20,7 @@
|
||||
export let style: Styles = {};
|
||||
export let theme: string;
|
||||
export let caption: string;
|
||||
export let bokeh_version: string | null;
|
||||
</script>
|
||||
|
||||
<Block
|
||||
@ -32,5 +33,5 @@
|
||||
|
||||
<StatusTracker {...loading_status} />
|
||||
|
||||
<Plot {value} {target} {theme} {caption} on:change />
|
||||
<Plot {value} {target} {theme} {caption} {bokeh_version} on:change />
|
||||
</Block>
|
||||
|
@ -5,7 +5,7 @@
|
||||
import { colors as color_palette, ordered_colors } from "@gradio/theme";
|
||||
import { get_next_color } from "@gradio/utils";
|
||||
import { Vega } from "svelte-vega";
|
||||
import { afterUpdate, onDestroy } from "svelte";
|
||||
import { afterUpdate, beforeUpdate, onDestroy } from "svelte";
|
||||
import { create_config } from "./utils";
|
||||
import { Empty } from "@gradio/atoms";
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
export let colors: Array<string> = [];
|
||||
export let theme: string;
|
||||
export let caption: string;
|
||||
export let bokeh_version: string | null;
|
||||
|
||||
function get_color(index: number) {
|
||||
let current_color = colors[index % colors.length];
|
||||
@ -32,8 +33,34 @@
|
||||
|
||||
$: darkmode = theme == "dark";
|
||||
|
||||
$: if (value && value.type == "altair") {
|
||||
spec = JSON.parse(value["plot"]);
|
||||
$: plot = value?.plot;
|
||||
$: type = value?.type;
|
||||
|
||||
// Need to keep track of this because
|
||||
// otherwise embed_bokeh may try to embed before
|
||||
// bokeh is loaded
|
||||
$: bokeh_loaded = window.Bokeh === undefined
|
||||
|
||||
function embed_bokeh(plot, type, bokeh_loaded){
|
||||
if (document){
|
||||
if (document.getElementById("bokehDiv")) {
|
||||
document.getElementById("bokehDiv").innerHTML = "";
|
||||
}
|
||||
}
|
||||
if (type == "bokeh" && window.Bokeh) {
|
||||
if (!bokeh_loaded) {
|
||||
load_bokeh();
|
||||
bokeh_loaded = true;
|
||||
}
|
||||
let plotObj = JSON.parse(plot);
|
||||
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
|
||||
}
|
||||
}
|
||||
|
||||
$: embed_bokeh(plot, type, bokeh_loaded);
|
||||
|
||||
$: if (type == "altair") {
|
||||
spec = JSON.parse(plot);
|
||||
const config = create_config(darkmode);
|
||||
spec.config = config;
|
||||
switch (value.chart || "") {
|
||||
@ -75,13 +102,13 @@
|
||||
let plotDiv;
|
||||
let plotlyGlobalStyle;
|
||||
|
||||
const main_src = "https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js";
|
||||
const main_src = `https://cdn.bokeh.org/bokeh/release/bokeh-${bokeh_version}.min.js`;
|
||||
|
||||
const plugins_src = [
|
||||
"https://cdn.pydata.org/bokeh/release/bokeh-widgets-2.4.2.min.js",
|
||||
"https://cdn.pydata.org/bokeh/release/bokeh-tables-2.4.2.min.js",
|
||||
"https://cdn.pydata.org/bokeh/release/bokeh-gl-2.4.2.min.js",
|
||||
"https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js"
|
||||
`https://cdn.pydata.org/bokeh/release/bokeh-widgets-${bokeh_version}.min.js`,
|
||||
`https://cdn.pydata.org/bokeh/release/bokeh-tables-${bokeh_version}.min.js`,
|
||||
`https://cdn.pydata.org/bokeh/release/bokeh-gl-${bokeh_version}.min.js`,
|
||||
`https://cdn.pydata.org/bokeh/release/bokeh-api-${bokeh_version}.min.js`
|
||||
];
|
||||
|
||||
function load_plugins() {
|
||||
@ -100,7 +127,7 @@
|
||||
script.onload = handleBokehLoaded;
|
||||
script.src = main_src;
|
||||
document.head.appendChild(script);
|
||||
|
||||
bokeh_loaded = true;
|
||||
return script;
|
||||
}
|
||||
|
||||
@ -115,10 +142,9 @@
|
||||
}
|
||||
}
|
||||
|
||||
const main_script = load_bokeh();
|
||||
const main_script = bokeh_version ? load_bokeh() : null
|
||||
|
||||
let plugin_scripts = [];
|
||||
// Bokeh
|
||||
|
||||
const resolves = [];
|
||||
const bokehPromises = Array(5)
|
||||
@ -126,7 +152,7 @@
|
||||
.map((_, i) => createPromise(i));
|
||||
|
||||
const initializeBokeh = (index) => {
|
||||
if (value && value["type"] == "bokeh") {
|
||||
if (type == "bokeh") {
|
||||
resolves[index]();
|
||||
}
|
||||
};
|
||||
@ -142,23 +168,14 @@
|
||||
plugin_scripts = load_plugins();
|
||||
}
|
||||
|
||||
Promise.all(bokehPromises).then(() => {
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
|
||||
});
|
||||
|
||||
afterUpdate(() => {
|
||||
if (value && value["type"] == "plotly") {
|
||||
if (type == "plotly") {
|
||||
load_plotly_css();
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
let plotObj = JSON.parse(plot);
|
||||
plotObj.layout.title
|
||||
? (plotObj.layout.margin = { autoexpand: true })
|
||||
: (plotObj.layout.margin = { l: 0, r: 0, b: 0, t: 0 });
|
||||
Plotly.react(plotDiv, plotObj);
|
||||
} else if (value && value["type"] == "bokeh") {
|
||||
document.getElementById("bokehDiv").innerHTML = "";
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
window.Bokeh.embed.embed_item(plotObj, "bokehDiv");
|
||||
}
|
||||
});
|
||||
|
||||
@ -170,11 +187,11 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if value && value["type"] == "plotly"}
|
||||
{#if value && type == "plotly"}
|
||||
<div bind:this={plotDiv} />
|
||||
{:else if value && value["type"] == "bokeh"}
|
||||
<div id="bokehDiv" />
|
||||
{:else if value && value["type"] == "altair"}
|
||||
{:else if type == "bokeh"}
|
||||
<div id="bokehDiv" class="gradio-bokeh"/>
|
||||
{:else if type == "altair"}
|
||||
<div class="altair layout">
|
||||
<Vega {spec} />
|
||||
{#if caption}
|
||||
@ -183,16 +200,22 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if value && value["type"] == "matplotlib"}
|
||||
{:else if type == "matplotlib"}
|
||||
<div class="matplotlib layout">
|
||||
<!-- svelte-ignore a11y-missing-attribute -->
|
||||
<img src={value["plot"]} />
|
||||
<img src={plot} />
|
||||
</div>
|
||||
{:else}
|
||||
<Empty size="large" unpadded_box={true}><PlotIcon /></Empty>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
|
||||
.gradio-bokeh {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.layout {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
Loading…
Reference in New Issue
Block a user