2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-03-31 12:20:26 +08:00

Format-The-Codebase

- format the codebase
- add format checkers to circleci
This commit is contained in:
Ömer Faruk Özdemir 2022-02-09 10:40:05 +03:00
parent 7619ae76d1
commit 87d7fbee61
17 changed files with 64 additions and 56 deletions

@ -37,6 +37,12 @@ jobs:
- run:
command: |
mkdir screenshots
- run:
command: |
black --check gradio test
isort --check-only gradio test
flake8 --max-line-length=160 gradio test
- run:
command: |
. venv/bin/activate

@ -6,19 +6,11 @@ import math
import numpy as np
import torch
from pytorch_transformers import (
WEIGHTS_NAME,
BertConfig,
BertForQuestionAnswering,
BertTokenizer,
)
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from utils import (
get_answer,
input_to_squad_example,
squad_examples_to_features,
to_list,
)
from utils import (get_answer, input_to_squad_example,
squad_examples_to_features, to_list)
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"]

@ -5,7 +5,8 @@ import math
import numpy as np
import torch
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
whitespace_tokenize)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

@ -1,9 +1,10 @@
import pkg_resources
from gradio.routes import get_state, set_state
from gradio.flagging import FlaggingCallback, SimpleCSVLogger, CSVLogger, HuggingFaceDatasetSaver
from gradio.flagging import (CSVLogger, FlaggingCallback,
HuggingFaceDatasetSaver, SimpleCSVLogger)
from gradio.interface import Interface, close_all, reset_all
from gradio.mix import Parallel, Series
from gradio.routes import get_state, set_state
current_pkg_version = pkg_resources.require("gradio")[0].version
__version__ = current_pkg_version

@ -2,9 +2,9 @@ import base64
import json
import re
import tempfile
from pydantic import MissingError
import requests
from pydantic import MissingError
from gradio import inputs, outputs

@ -6,9 +6,9 @@ automatically added to a registry, which allows them to be easily referenced in
from __future__ import annotations
import os
import json
import math
import os
import tempfile
import warnings
from numbers import Number

@ -21,14 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import (
encryptor,
interpretation,
networking, # type: ignore
queueing,
strings,
utils,
)
from gradio import networking # type: ignore
from gradio import encryptor, interpretation, queueing, strings, utils
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import InputComponent
@ -703,7 +697,7 @@ class Interface:
server_port, path_to_local_server, app, server = networking.start_server(
self, server_name, server_port, ssl_keyfile, ssl_certfile
)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"

@ -78,7 +78,7 @@ def start_server(
server_name: Optional[str] = None,
server_port: Optional[int] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
) -> Tuple[int, str, fastapi.FastAPI, threading.Thread, None]:
"""Launches a local server running the provided Interface
Parameters:
@ -109,14 +109,16 @@ def start_server(
port = server_port
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
if ssl_keyfile is not None:
if ssl_certfile is None:
raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")
raise ValueError(
"ssl_certfile must be provided if ssl_keyfile is provided."
)
path_to_local_server = "https://{}:{}/".format(url_host_name, port)
else:
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
auth = interface.auth
if auth is not None:
if not callable(auth):
@ -141,8 +143,14 @@ def start_server(
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning",
ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile)
config = uvicorn.Config(
app=app,
port=port,
host=server_name,
log_level="warning",
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
server = Server(config=config)
server.run_in_thread()
return port, path_to_local_server, app, server

@ -6,7 +6,7 @@ from __future__ import annotations
import csv
import os
import shutil
from typing import Any, List, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, List, Tuple
from gradio.flagging import CSVLogger

@ -2,8 +2,8 @@ import json
import os
import sqlite3
import time
from typing import Dict, Tuple
import uuid
from typing import Dict, Tuple
import requests

@ -4,7 +4,6 @@ import unittest
from gradio import encryptor, processing_utils
from gradio.test_data import BASE64_IMAGE
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

@ -245,9 +245,7 @@ class TestLoadInterface(unittest.TestCase):
class TestLoadFromPipeline(unittest.TestCase):
def test_question_answering(self):
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
output = pipe(
"My name is Sylvain and I work at Hugging Face in Brooklyn"
)
output = pipe("My name is Sylvain and I work at Hugging Face in Brooklyn")
self.assertIsNotNone(output)

@ -1,8 +1,8 @@
from difflib import SequenceMatcher
import json
import os
import tempfile
import unittest
from difflib import SequenceMatcher
from re import sub
import numpy as np

@ -6,7 +6,8 @@ import numpy as np
import gradio.interpretation
import gradio.test_data
from gradio import Interface
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
from gradio.processing_utils import (decode_base64_to_image,
encode_array_to_base64)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

@ -9,7 +9,7 @@ import warnings
import aiohttp
from fastapi.testclient import TestClient
from gradio import Interface, flagging, networking, reset_all, queueing
from gradio import Interface, flagging, networking, queueing, reset_all
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

@ -7,13 +7,12 @@ import re
import markdown2
import requests
from jinja2 import Template
from render_html_helpers import generate_meta_image
from gradio.inputs import InputComponent
from gradio.interface import Interface
from gradio.outputs import OutputComponent
from render_html_helpers import generate_meta_image
GRADIO_DIR = "../../"
GRADIO_GUIDES_DIR = os.path.join(GRADIO_DIR, "guides")
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
@ -28,7 +27,9 @@ def render_index():
tweets = json.load(tweets_file)
star_request = requests.get("https://api.github.com/repos/gradio-app/gradio").json()
star_count = (
"{:,}".format(star_request["stargazers_count"]) if "stargazers_count" in star_request else ""
"{:,}".format(star_request["stargazers_count"])
if "stargazers_count" in star_request
else ""
)
with open("src/index_template.html", encoding="utf-8") as template_file:
template = Template(template_file.read())
@ -64,7 +65,13 @@ for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
spaces = guide_content.split("related_spaces: ")[1].split("\n")[0].split(", ")
url = f"https://gradio.app/{guide_name}/"
guide_content = "\n".join([line for line in guide_content.split("\n") if not (line.startswith("tags: ") or line.startswith("related_spaces: "))])
guide_content = "\n".join(
[
line
for line in guide_content.split("\n")
if not (line.startswith("tags: ") or line.startswith("related_spaces: "))
]
)
guides.append(
{

@ -1,6 +1,8 @@
import cairo
import os
import cairo
def add_line_breaks(text, num_char):
if len(text) > num_char:
text_list = text.split()
@ -8,7 +10,7 @@ def add_line_breaks(text, num_char):
total_count = 0
count = 0
for word in text_list:
if total_count > num_char*5:
if total_count > num_char * 5:
text = text[:-1]
text += "..."
break
@ -19,9 +21,10 @@ def add_line_breaks(text, num_char):
count = 0
else:
text += word + " "
total_count += len(word+" ")
total_count += len(word + " ")
return text
return text
return text
def generate_meta_image(guide):
IMG_GUIDE_LOCATION = "dist/assets/img/guides"
@ -29,9 +32,8 @@ def generate_meta_image(guide):
surface = cairo.ImageSurface.create_from_png("src/assets/img/guides/base-image.png")
ctx = cairo.Context(surface)
ctx.scale(500, 500)
ctx.set_source_rgba(0.611764706,0.639215686,0.6862745098,1)
ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL,
cairo.FONT_WEIGHT_NORMAL)
ctx.set_source_rgba(0.611764706, 0.639215686, 0.6862745098, 1)
ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
ctx.set_font_size(0.15)
ctx.move_to(0.3, 0.55)
ctx.show_text("gradio.app/guides")
@ -41,19 +43,18 @@ def generate_meta_image(guide):
tags = " | ".join(tags)
ctx.move_to(0.3, 2.2)
ctx.show_text(tags)
ctx.set_source_rgba(0.352941176,0.352941176,0.352941176,1)
ctx.set_source_rgba(0.352941176, 0.352941176, 0.352941176, 1)
ctx.set_font_size(0.28)
title_breaked = add_line_breaks(title, 10)
if "\n" in title_breaked:
for i, t in enumerate(title_breaked.split("\n")):
ctx.move_to(0.3, 0.9+i*0.4)
ctx.move_to(0.3, 0.9 + i * 0.4)
ctx.show_text(t)
else:
ctx.move_to(0.3, 1.0)
ctx.show_text(title_breaked)
os.makedirs(IMG_GUIDE_LOCATION, exist_ok=True )
os.makedirs(IMG_GUIDE_LOCATION, exist_ok=True)
image_path = f"{IMG_GUIDE_LOCATION}/{guide_name}.png"
surface.write_to_png(image_path)