// gradio/ui/packages/app/test/outbreak_forecast.spec.ts
// (TypeScript, 78 lines, 11 KiB)

import { test, expect, Page } from "@playwright/test";
import { BASE64_PLOT_IMG } from "./media_data";
/**
 * Registers a route handler on `page` for requests matching `**​/config` and
 * answers them with the `config.json` file of the given demo.
 *
 * @param page - Playwright page on which the route is installed.
 * @param demo - Demo directory name; interpolated into
 *   `../../../demo/<demo>/config.json`.
 * @returns Whatever `page.route` returns for the registration.
 */
function mock_demo(page: Page, demo: string) {
  const config_path = `../../../demo/${demo}/config.json`;
  // The response carries a wildcard CORS header alongside the file contents.
  const cors_headers = { "Access-Control-Allow-Origin": "*" };

  return page.route("**/config", (route) =>
    route.fulfill({ headers: cors_headers, path: config_path })
  );
}
/**
 * Registers a route handler on `page` for requests matching
 * `**​/api/predict/`. Each intercepted request's POST body is parsed as JSON,
 * its `fn_index` field is read, and the request is fulfilled with
 * `{ data: body[fn_index] }` serialized as JSON.
 *
 * @param page - Playwright page on which the route is installed.
 * @param body - Canned responses, indexed by the request's `fn_index`.
 * @returns Whatever `page.route` returns for the registration.
 * @throws Error (inside the route handler) when the intercepted request has
 *   no POST body, since `fn_index` cannot be determined.
 */
function mock_api(page: Page, body: Array<unknown>) {
  return page.route("**/api/predict/", (route) => {
    const post_data = route.request().postData();
    // Fail with a descriptive message instead of an opaque TypeError from
    // dereferencing the result of parsing a missing body.
    if (post_data == null) {
      throw new Error(
        "mock_api: intercepted /api/predict/ request has no POST body, so fn_index cannot be read"
      );
    }

    const { fn_index } = JSON.parse(post_data);
    return route.fulfill({
      headers: {
        "Access-Control-Allow-Origin": "*"
      },
      body: JSON.stringify({
        data: body[fn_index]
      })
    });
  });
}
// Drives the outbreak_forecast demo with a mocked config and a mocked predict
// endpoint that returns a matplotlib plot, then checks that the first <img>
// on the page has the mocked base64 payload as its `src`.
test("matplotlib", async ({ page }) => {
  await mock_demo(page, "outbreak_forecast");
  await mock_api(page, [[{ type: "matplotlib", plot: BASE64_PLOT_IMG }]]);
  await page.goto("http://localhost:3000");

  await page
    .locator("text=Plot Type MatplotlibPlotlyBokeh >> select")
    .selectOption("Matplotlib");
  await page
    .locator("text=Month JanuaryFebruaryMarchAprilMay >> select")
    .selectOption("January");
  await page.locator('label:has-text("Social Distancing?")').click();

  // `locator()` is synchronous; there is nothing to await here.
  const submit_button = page.locator("text=Submit");
  await Promise.all([
    submit_button.click(),
    page.waitForResponse("**/api/predict/")
  ]);

  // Use a retrying locator assertion rather than reading the attribute once:
  // the image may not have been rendered yet when the response arrives.
  await expect(page.locator("img").nth(0)).toHaveAttribute(
    "src",
    BASE64_PLOT_IMG
  );
});
test("plotly", async ({ page }) => {
await mock_demo(page, "outbreak_forecast");
await mock_api(page, [
[
{
type: "plotly",
plot: '{"data":[{"hovertemplate":"variable=USA<br>day=%{x}<br>value=%{y}<extra></extra>","legendgroup":"USA","line":{"color":"#636efa","dash":"solid"},"marker":{"symbol":"circle"},"mode":"lines","name":"USA","orientation":"v","showlegend":true,"x":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30],"xaxis":"x","y":[0.0,351.0,1212.8467905971102,2504.9990397077145,4190.875605304019,6246.872925984764,8655.783606149129,11404.190112523678,14481.168170041885,17877.550395830687,21585.46945740342,25598.055992414193,29909.228965302143,34513.54453420909,39406.083696107235,44582.36661187451,50038.28586647387,55770.05352459305,61774.158465302986,68047.33152596047,74586.51668075123,81388.8469530371,88451.62409095351,95772.30127063072,103348.46826154085,111177.83861367931,119258.23851978898,127587.5970765873,136163.9377231475,144985.3706765514,154050.0862177657],"yaxis":"y","type":"scatter"},{"hovertemplate":"variable=Canada<br>day=%{x}<br>value=%{y}<extra></extra>","legendgroup":"Canada","line":{"color":"#EF553B","dash":"solid"},"marker":{"symbol":"circle"},"mode":"lines","name":"Canada","orientation":"v","showlegend":true,"x":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30],"xaxis":"x","y":[0.0,41.0,141.67156243442028,292.6067254359439,489.53247811243523,729.6917093030636,1011.0744383251119,1332.1133749671533,1691.5324643068868,2088.2608724474594,2521.379623229459,2990.086312504222,3493.6706198785982,4031.4966549930846,4602.989833448423,5207.626869193318,5844.9279787049245,6514.450696604887,7215.7848919584685,7948.548696764613,8712.38513934701,9506.959330696642,10331.95609039628,11187.07792619903,12072.043301205626,12986.585137210404,13930.449513707545,14903.394530313617,15905.189306692442,16935.61309897039,17994.45451546551],"yaxis":"y","type":"scatter"}],"layout":{"template":{"data":{"histogram2dcontour":[{"type":"histogram2dcontour","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1
111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"choropleth":[{"type":"choropleth","colorbar":{"outlinewidth":0,"ticks":""}}],"histogram2d":[{"type":"histogram2d","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"heatmap":[{"type":"heatmap","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"heatmapgl":[{"type":"heatmapgl","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"contourcarpet":[{"type":"contourcarpet","colorbar":{"outlinewidth":0,"ticks":""}}],"contour":[{"type":"contour","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d0887"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c179e"],[0.4444444444444444,"#bd3786"],[0.5555555555555556,"#d8576b"],[0.6666666666666666,"#ed7953"],[0.7777777777777778,"#fb9f3a"],[0.8888888888888888,"#fdca26"],[1.0,"#f0f921"]]}],"surface":[{"type":"surface","colorbar":{"outlinewidth":0,"ticks":""},"colorscale":[[0.0,"#0d08
87"],[0.1111111111111111,"#46039f"],[0.2222222222222222,"#7201a8"],[0.3333333333333333,"#9c
}
]
]);
await page.goto("http://localhost:3000");
await page
.locator("text=Plot Type MatplotlibPlotlyBokeh >> select")
.selectOption("Plotly");
await page
.locator("text=Month JanuaryFebruaryMarchAprilMay >> select")
.selectOption("January");
await page.locator('label:has-text("Social Distancing?")').click();
const submit_button = await page.locator("text=Submit");
await Promise.all([
submit_button.click(),
page.waitForResponse("**/api/predict/")
]);
await expect(page.locator(".js-plotly-plot")).toHaveCount(1);
});