Merge branch 'master' into Dev-Requirements

This commit is contained in:
Ömer Faruk Özdemir 2022-01-25 09:17:36 +03:00
commit ea0f78a85e
11 changed files with 101 additions and 19 deletions

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: gradio
Version: 2.7.5
Version: 2.7.5.2
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import inspect
import io
import os
import posixpath
import secrets
@ -20,7 +21,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from starlette.responses import RedirectResponse
from gradio import queueing, utils
from gradio import encryptor, queueing, utils
from gradio.process_examples import load_from_cache, process_example
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
@ -126,6 +127,21 @@ def static_resource(path: str):
raise HTTPException(status_code=404, detail="Static file not found")
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
def file(path):
if app.interface.encrypt and isinstance(
app.interface.examples, str) and path.startswith(
app.interface.examples):
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
encrypted_data = encrypted_file.read()
file_data = encryptor.decrypt(
app.interface.encryption_key, encrypted_data)
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path))
else:
return FileResponse(safe_join(app.cwd, path))
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
@app.get("/api/", response_class=HTMLResponse)
def api_docs(request: Request):
@ -244,9 +260,8 @@ async def interpret(request: Request):
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request):
body = await request.json()
data = body["data"]
action = body["action"]
job_hash, queue_position = queueing.push({"data": data}, action)
job_hash, queue_position = queueing.push(body, action)
return {"hash": job_hash, "queue_position": queue_position}

View File

@ -1,5 +1,6 @@
import os
import shutil
from typing import Any, Dict
from gradio import processing_utils
@ -32,7 +33,13 @@ class Component:
"""
return {}
def save_flagged(self, dir, label, data, encryption_key):
def save_flagged(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool
) -> Any:
"""
Saves flagged data from component
"""
@ -44,7 +51,16 @@ class Component:
"""
return data
def save_flagged_file(self, dir, label, data, encryption_key):
def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool
) -> str:
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
"""
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key)
@ -64,7 +80,15 @@ class Component:
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
return label + "/" + new_file_name
def restore_flagged_file(self, dir, file, encryption_key):
def restore_flagged_file(
self,
dir: str,
file: str,
encryption_key: bool,
) -> Dict[str, Any]:
"""
Loads flagged data from file and returns it
"""
data = processing_utils.encode_file_to_base64(
os.path.join(dir, file), encryption_key=encryption_key
)

View File

@ -3,12 +3,19 @@ from Crypto.Cipher import AES
from Crypto.Hash import SHA256
def get_key(password):
def get_key(
password: str
) -> bytes:
"""Generates an encryption key based on the password provided."""
key = SHA256.new(password.encode()).digest()
return key
def encrypt(key, source):
def encrypt(
key: bytes,
source: bytes
) -> bytes:
"""Encrypts source data using the provided encryption key"""
IV = Random.new().read(AES.block_size) # generate IV
encryptor = AES.new(key, AES.MODE_CBC, IV)
padding = AES.block_size - len(source) % AES.block_size # calculate needed padding
@ -17,7 +24,10 @@ def encrypt(key, source):
return data
def decrypt(key, source):
def decrypt(
key: bytes,
source: bytes
) -> bytes:
IV = source[: AES.block_size] # extract the IV from the beginning
decryptor = AES.new(key, AES.MODE_CBC, IV)
data = decryptor.decrypt(source[AES.block_size :]) # decrypt

View File

@ -109,9 +109,8 @@ class SimpleCSVLogger(FlaggingCallback):
class CSVLogger(FlaggingCallback):
"""
The default implementation of the FlaggingCallback abstract class.
Logs the input and output data to a CSV file.
Logs the input and output data to a CSV file. Supports encryption.
"""
def setup(self, flagging_dir: str):
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)

View File

@ -244,8 +244,8 @@ class Interface:
self.description = description
if article is not None:
article = utils.readme_to_html(article)
article = markdown2.markdown(article, extras=["fenced-code-blocks"])
article = markdown2.markdown(
article, extras=["fenced-code-blocks"])
self.article = article
self.thumbnail = thumbnail

View File

View File

@ -1 +1 @@
2.7.5
2.7.5.2

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name="gradio",
version="2.7.5",
version="2.7.5.2",
include_package_data=True,
description="Python library for easily interacting with trained machine learning models",
author="Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq",

33
test/test_encryptor.py Normal file
View File

@ -0,0 +1,33 @@
import os
import unittest
from gradio import encryptor, processing_utils
from gradio.test_data import BASE64_IMAGE
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestKeyGenerator(unittest.TestCase):
def test_same_pass(self):
key1 = encryptor.get_key("test")
key2 = encryptor.get_key("test")
self.assertEquals(key1, key2)
def test_diff_pass(self):
key1 = encryptor.get_key("test")
key2 = encryptor.get_key("diff_test")
self.assertNotEquals(key1, key2)
class TestEncryptorDecryptor(unittest.TestCase):
def test_same_pass(self):
key = encryptor.get_key("test")
data, _ = processing_utils.decode_base64_to_binary(BASE64_IMAGE)
encrypted_data = encryptor.encrypt(key, data)
decrypted_data = encryptor.decrypt(key, encrypted_data)
self.assertEquals(data, decrypted_data)
if __name__ == "__main__":
unittest.main()

View File

@ -1,12 +1,11 @@
import tempfile
import unittest
import unittest.mock as mock
import gradio as gr
from gradio import flagging
class TestFlagging(unittest.TestCase):
class TestDefaultFlagging(unittest.TestCase):
def test_default_flagging_handler(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
@ -17,6 +16,8 @@ class TestFlagging(unittest.TestCase):
self.assertEqual(row_count, 2) # 3 rows written including header
io.close()
class TestSimpleFlagging(unittest.TestCase):
def test_simple_csv_flagging_handler(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
@ -33,6 +34,6 @@ class TestFlagging(unittest.TestCase):
self.assertEqual(row_count, 1) # no header
io.close()
if __name__ == "__main__":
unittest.main()