Merge pull request #574 from gradio-app/Format-The-Codebase

Format-The-Codebase
This commit is contained in:
Ömer Faruk Özdemir 2022-02-10 11:18:52 +03:00 committed by GitHub
commit c9298b3802
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 121 additions and 110 deletions

View File

@ -42,6 +42,18 @@ jobs:
. venv/bin/activate
coverage run -m pytest
coverage xml
- run:
command: |
. venv/bin/activate
python -m black --check gradio test
- run:
command: |
. venv/bin/activate
python -m isort --profile=black --check-only gradio test
- run:
command: |
. venv/bin/activate
python -m flake8 --ignore=E731,E501,E722,W503,E126,F401,E203 gradio test
- codecov/upload:
file: 'coverage.xml'
- store_artifacts:

View File

@ -68,5 +68,8 @@ All PRs should be against `master`. Direct commits to master are blocked, and PR
* A maintainer (@abidlabs, @aliabid94, @aliabd, @AK391, or @dawoodkhan82) is tagged in the PR comments and asked to complete a review
We ask that you make sure initial CI checks are passing before requesting a review. One of the Gradio maintainers will merge the PR when all the checks are passing.
Do not forget the format the codebase before pushing.
```
bash scripts/run_frontend.sh
```
*Could these guidelines be clearer? Feel free to open a PR to help us faciltiate open-source contributions!*

View File

@ -1,9 +1,14 @@
import pkg_resources
from gradio.routes import get_state, set_state
from gradio.flagging import FlaggingCallback, SimpleCSVLogger, CSVLogger, HuggingFaceDatasetSaver
from gradio.flagging import (
CSVLogger,
FlaggingCallback,
HuggingFaceDatasetSaver,
SimpleCSVLogger,
)
from gradio.interface import Interface, close_all, reset_all
from gradio.mix import Parallel, Series
from gradio.routes import get_state, set_state
current_pkg_version = pkg_resources.require("gradio")[0].version
__version__ = current_pkg_version

View File

@ -1,8 +1,6 @@
import base64
import json
import re
import tempfile
from pydantic import MissingError
import requests
@ -252,7 +250,7 @@ def load_interface(name, src=None, api_key=None, alias=None):
) # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
assert (
len(tokens) > 1
), "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
), "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
src = tokens[0]
name = "/".join(tokens[1:])
assert src.lower() in repos, "parameter: src must be one of {}".format(repos.keys())
@ -261,7 +259,7 @@ def load_interface(name, src=None, api_key=None, alias=None):
def interface_params_from_config(config_dict):
## instantiate input component and output component
# instantiate input component and output component
config_dict["inputs"] = [
inputs.get_input_instance(component)
for component in config_dict["input_components"]

View File

@ -6,12 +6,11 @@ automatically added to a registry, which allows them to be easily referenced in
from __future__ import annotations
import os
import json
import math
import os
import tempfile
import warnings
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import numpy as np

View File

@ -1,5 +1,5 @@
"""
This is the core file in the `gradio` package, and defines the Interface class,
This is the core file in the `gradio` package, and defines the Interface class,
including various methods for constructing an interface and then launching it.
"""
@ -15,20 +15,13 @@ import time
import warnings
import weakref
import webbrowser
from logging import warning
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import (
encryptor,
interpretation,
networking, # type: ignore
queueing,
strings,
utils,
)
from gradio import networking # type: ignore
from gradio import encryptor, interpretation, queueing, strings, utils
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import InputComponent
@ -359,7 +352,7 @@ class Interface:
)
if allow_flagging is None:
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
if allow_flagging == True:
if allow_flagging is True:
warnings.warn(
"The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
@ -368,7 +361,7 @@ class Interface:
self.allow_flagging = "manual"
elif allow_flagging == "manual":
self.allow_flagging = "manual"
elif allow_flagging == False:
elif allow_flagging is False:
warnings.warn(
"The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
@ -703,7 +696,7 @@ class Interface:
server_port, path_to_local_server, app, server = networking.start_server(
self, server_name, server_port, ssl_keyfile, ssl_certfile
)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"

View File

@ -78,7 +78,7 @@ def start_server(
server_name: Optional[str] = None,
server_port: Optional[int] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
) -> Tuple[int, str, fastapi.FastAPI, threading.Thread, None]:
"""Launches a local server running the provided Interface
Parameters:
@ -109,14 +109,16 @@ def start_server(
port = server_port
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
if ssl_keyfile is not None:
if ssl_certfile is None:
raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")
raise ValueError(
"ssl_certfile must be provided if ssl_keyfile is provided."
)
path_to_local_server = "https://{}:{}/".format(url_host_name, port)
else:
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
auth = interface.auth
if auth is not None:
if not callable(auth):
@ -141,34 +143,31 @@ def start_server(
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning",
ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile)
config = uvicorn.Config(
app=app,
port=port,
host=server_name,
log_level="warning",
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
server = Server(config=config)
server.run_in_thread()
return port, path_to_local_server, app, server
def setup_tunnel(local_server_port: int, endpoint: str) -> str:
response = url_request(
response = requests.get(
endpoint + "/v1/tunnel-request" if endpoint is not None else GRADIO_API_SERVER
)
if response and response.code == 200:
if response and response.status_code == 200:
try:
payload = json.loads(response.read().decode("utf-8"))[0]
payload = response.json()[0]
return create_tunnel(payload, LOCALHOST_NAME, local_server_port)
except Exception as e:
raise RuntimeError(str(e))
def url_request(url: str) -> Optional[http.client.HTTPResponse]:
try:
req = urllib.request.Request(
url=url, headers={"content-type": "application/json"}
)
res = urllib.request.urlopen(req, timeout=10)
return res
except Exception as e:
raise RuntimeError(str(e))
else:
raise RuntimeError("Could not get share link from Gradio API Server.")
def url_ok(url: str) -> bool:

View File

@ -13,7 +13,7 @@ import tempfile
import warnings
from numbers import Number
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional
import numpy as np
import pandas as pd

View File

@ -6,7 +6,7 @@ from __future__ import annotations
import csv
import os
import shutil
from typing import Any, List, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, List, Tuple
from gradio.flagging import CSVLogger

View File

@ -16,6 +16,7 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
from pydub import AudioSegment
#########################
# IMAGE PRE-PROCESSING
#########################
@ -336,12 +337,6 @@ def _convert(image, dtype, force_copy=False, uniform=False):
"""
kind = a.dtype.kind
if n > m and a.max() < 2**m:
mnew = int(np.ceil(m / 2) * 2)
if mnew > m:
dtype = "int{}".format(mnew)
else:
dtype = "uint{}".format(mnew)
n = int(np.ceil(n / 2) * 2)
return a.astype(_dtype_bits(kind, m))
elif n == m:
return a.copy() if copy else a

View File

@ -2,8 +2,8 @@ import json
import os
import sqlite3
import time
from typing import Dict, Tuple
import uuid
from typing import Dict, Tuple
import requests
@ -26,7 +26,7 @@ def queue_thread(path_to_local_server: str) -> None:
fail_job(hash, response.text)
else:
time.sleep(1)
except Exception as e:
except:
time.sleep(1)
pass

View File

@ -45,7 +45,6 @@ app.add_middleware(
allow_headers=["*"],
)
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
@ -198,7 +197,7 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
flag_index = None
if body.get("example_id") != None:
if body.get("example_id") is not None:
example_id = body["example_id"]
if app.interface.cache_examples:
prediction = await run_in_threadpool(

View File

@ -82,7 +82,7 @@ def create_tunnel(payload, local_server, local_server_port):
)
except Exception as e:
print(
"*** Failed to connect to {}:{}: {}}".format(
"*** Failed to connect to {}:{}: {}".format(
payload["host"], int(payload["port"]), e
)
)

View File

@ -22,7 +22,6 @@ import gradio
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio import Interface
analytics_url = "https://api.gradio.app/"
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"

10
scripts/format.sh Normal file
View File

@ -0,0 +1,10 @@
#!/bin/bash
if [ -z "$(ls | grep CONTRIBUTING.md)" ]; then
echo "Please run the script from repo directory"
exit -1
else
echo "Installing formatting with black and isort, also checking for standards with flake8"
python -m black gradio test
python -m isort --profile=black gradio test
python -m flake8 --ignore=E731,E501,E722,W503,E126,F401,E203 gradio test
fi

View File

@ -4,7 +4,6 @@ import unittest
from gradio import encryptor, processing_utils
from gradio.test_data import BASE64_IMAGE
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

View File

@ -7,8 +7,8 @@ import transformers
import gradio as gr
"""
WARNING: These tests have an external dependency: namely that Hugging Face's
Hub and Space APIs do not change, and they keep their most famous models up.
WARNING: These tests have an external dependency: namely that Hugging Face's
Hub and Space APIs do not change, and they keep their most famous models up.
So if, e.g. Spaces is down, then these test will not pass.
"""
@ -243,11 +243,9 @@ class TestLoadInterface(unittest.TestCase):
class TestLoadFromPipeline(unittest.TestCase):
def test_question_answering(self):
def test_text_to_text_model_from_pipeline(self):
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
output = pipe(
"My name is Sylvain and I work at Hugging Face in Brooklyn"
)
output = pipe("My name is Sylvain and I work at Hugging Face in Brooklyn")
self.assertIsNotNone(output)

View File

@ -1,14 +1,12 @@
from difflib import SequenceMatcher
import json
import os
import tempfile
import unittest
from re import sub
from difflib import SequenceMatcher
import numpy as np
import pandas
import PIL
from pydub import AudioSegment
import gradio as gr
@ -668,8 +666,8 @@ class TestDataframe(unittest.TestCase):
self.assertEqual(iface.process([x_data])[0], [6])
x_data = [["Tim"], ["Jon"], ["Sal"]]
def get_last(l):
return l[-1]
def get_last(my_list):
return my_list[-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data])[0], ["Sal"])

View File

@ -1,8 +1,5 @@
import io
import socket
import sys
import tempfile
import threading
import unittest
import unittest.mock as mock
from contextlib import contextmanager
@ -10,9 +7,8 @@ from contextlib import contextmanager
import mlflow
import requests
import wandb
from comet_ml import Experiment
from gradio.interface import *
from gradio.interface import Interface, close_all, os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

View File

@ -6,7 +6,7 @@ import numpy as np
import gradio.interpretation
import gradio.test_data
from gradio import Interface
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
from gradio.processing_utils import decode_base64_to_image
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

View File

@ -9,7 +9,7 @@ import warnings
import aiohttp
from fastapi.testclient import TestClient
from gradio import Interface, flagging, networking, reset_all, queueing
from gradio import Interface, flagging, networking
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -94,11 +94,6 @@ class TestInterpretation(unittest.TestCase):
class TestURLs(unittest.TestCase):
def test_url_ok(self):
urllib.request.urlopen = mock.MagicMock(return_value="test")
res = networking.url_request("http://www.gradio.app")
self.assertEqual(res, "test")
def test_setup_tunnel(self):
networking.create_tunnel = mock.MagicMock(return_value="test")
res = networking.setup_tunnel(None, None)

View File

@ -138,7 +138,7 @@ class TestImage(unittest.TestCase):
xpoints = np.array([0, 6])
ypoints = np.array([0, 250])
fig = plt.figure()
p = plt.plot(xpoints, ypoints)
plt.plot(xpoints, ypoints)
self.assertTrue(
plot_output.postprocess(fig).startswith("data:image/png;base64,")
)

View File

@ -1,9 +1,7 @@
import os
import pathlib
import tempfile
import unittest
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
@ -38,7 +36,8 @@ class ImagePreprocessing(unittest.TestCase):
)
self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
# def test_encode_plot_to_base64(self): # Commented out because this is throwing errors on Windows. Possibly due to different matplotlib behavior on Windows?
# Commented out because this is throwing errors on Windows. Possibly due to different matplotlib behavior on Windows?
# def test_encode_plot_to_base64(self):
# plt.plot([1, 2, 3, 4])
# output_base64 = gr.processing_utils.encode_plot_to_base64(plt)
# self.assertEqual(output_base64, gr.test_data.BASE64_PLT_IMG)

View File

@ -2,11 +2,8 @@
import os
import unittest
import unittest.mock as mock
import requests
from gradio import Interface, queueing
from gradio import queueing
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

View File

@ -6,7 +6,7 @@ import unittest.mock as mock
from fastapi.testclient import TestClient
from gradio import Interface, flagging, networking, queueing, reset_all
from gradio import Interface, queueing, reset_all
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -51,7 +51,7 @@ class TestRoutes(unittest.TestCase):
)
self.assertEqual(response.status_code, 200)
def test_queue_push_route(self):
def test_queue_push_route_2(self):
queueing.get_status = mock.MagicMock(return_value=(None, None))
response = self.client.post("/api/queue/status/", json={"hash": "test"})
self.assertEqual(response.status_code, 200)

View File

@ -7,6 +7,7 @@ import unittest
import unittest.mock as mock
import paramiko
import requests
from gradio import Interface, networking, tunneling
@ -15,8 +16,8 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestTunneling(unittest.TestCase):
def test_create_tunnel(self):
response = networking.url_request(networking.GRADIO_API_SERVER)
payload = json.loads(response.read().decode("utf-8"))[0]
response = requests.get(networking.GRADIO_API_SERVER)
payload = response.json()[0]
io = Interface(lambda x: x, "text", "text")
_, path_to_local_server, _ = io.launch(prevent_thread_lock=True, share=False)
_, localhost, port = path_to_local_server.split(":")
@ -34,7 +35,7 @@ class TestVerbose(unittest.TestCase):
def setUp(self):
self.message = "print test"
self.capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = self.capturedOutput # and redirect stdout.
sys.stdout = self.capturedOutput # and redirect stdout.
def test_verbose_debug_true(self):
tunneling.verbose(self.message, debug_mode=True)

View File

@ -7,8 +7,16 @@ import warnings
import pkg_resources
import requests
import gradio
from gradio.utils import *
from gradio.utils import (
colab_check,
error_analytics,
get_local_ip_address,
ipython_check,
json,
launch_analytics,
readme_to_html,
version_check,
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

View File

@ -7,13 +7,12 @@ import re
import markdown2
import requests
from jinja2 import Template
from render_html_helpers import generate_meta_image
from gradio.inputs import InputComponent
from gradio.interface import Interface
from gradio.outputs import OutputComponent
from render_html_helpers import generate_meta_image
GRADIO_DIR = "../../"
GRADIO_GUIDES_DIR = os.path.join(GRADIO_DIR, "guides")
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
@ -28,7 +27,9 @@ def render_index():
tweets = json.load(tweets_file)
star_request = requests.get("https://api.github.com/repos/gradio-app/gradio").json()
star_count = (
"{:,}".format(star_request["stargazers_count"]) if "stargazers_count" in star_request else ""
"{:,}".format(star_request["stargazers_count"])
if "stargazers_count" in star_request
else ""
)
with open("src/index_template.html", encoding="utf-8") as template_file:
template = Template(template_file.read())
@ -64,7 +65,13 @@ for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
spaces = guide_content.split("related_spaces: ")[1].split("\n")[0].split(", ")
url = f"https://gradio.app/{guide_name}/"
guide_content = "\n".join([line for line in guide_content.split("\n") if not (line.startswith("tags: ") or line.startswith("related_spaces: "))])
guide_content = "\n".join(
[
line
for line in guide_content.split("\n")
if not (line.startswith("tags: ") or line.startswith("related_spaces: "))
]
)
guides.append(
{

View File

@ -1,6 +1,8 @@
import cairo
import os
import cairo
def add_line_breaks(text, num_char):
if len(text) > num_char:
text_list = text.split()
@ -8,7 +10,7 @@ def add_line_breaks(text, num_char):
total_count = 0
count = 0
for word in text_list:
if total_count > num_char*5:
if total_count > num_char * 5:
text = text[:-1]
text += "..."
break
@ -19,9 +21,10 @@ def add_line_breaks(text, num_char):
count = 0
else:
text += word + " "
total_count += len(word+" ")
total_count += len(word + " ")
return text
return text
return text
def generate_meta_image(guide):
IMG_GUIDE_LOCATION = "dist/assets/img/guides"
@ -29,9 +32,8 @@ def generate_meta_image(guide):
surface = cairo.ImageSurface.create_from_png("src/assets/img/guides/base-image.png")
ctx = cairo.Context(surface)
ctx.scale(500, 500)
ctx.set_source_rgba(0.611764706,0.639215686,0.6862745098,1)
ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL,
cairo.FONT_WEIGHT_NORMAL)
ctx.set_source_rgba(0.611764706, 0.639215686, 0.6862745098, 1)
ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
ctx.set_font_size(0.15)
ctx.move_to(0.3, 0.55)
ctx.show_text("gradio.app/guides")
@ -41,19 +43,18 @@ def generate_meta_image(guide):
tags = " | ".join(tags)
ctx.move_to(0.3, 2.2)
ctx.show_text(tags)
ctx.set_source_rgba(0.352941176,0.352941176,0.352941176,1)
ctx.set_source_rgba(0.352941176, 0.352941176, 0.352941176, 1)
ctx.set_font_size(0.28)
title_breaked = add_line_breaks(title, 10)
if "\n" in title_breaked:
for i, t in enumerate(title_breaked.split("\n")):
ctx.move_to(0.3, 0.9+i*0.4)
ctx.move_to(0.3, 0.9 + i * 0.4)
ctx.show_text(t)
else:
ctx.move_to(0.3, 1.0)
ctx.show_text(title_breaked)
os.makedirs(IMG_GUIDE_LOCATION, exist_ok=True )
os.makedirs(IMG_GUIDE_LOCATION, exist_ok=True)
image_path = f"{IMG_GUIDE_LOCATION}/{guide_name}.png"
surface.write_to_png(image_path)