mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Merge branch 'master' into Dev-Requirements
This commit is contained in:
commit
ea0f78a85e
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: gradio
|
||||
Version: 2.7.5
|
||||
Version: 2.7.5.2
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import posixpath
|
||||
import secrets
|
||||
@ -20,7 +21,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from gradio import queueing, utils
|
||||
from gradio import encryptor, queueing, utils
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
@ -126,6 +127,21 @@ def static_resource(path: str):
|
||||
raise HTTPException(status_code=404, detail="Static file not found")
|
||||
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if app.interface.encrypt and isinstance(
|
||||
app.interface.examples, str) and path.startswith(
|
||||
app.interface.examples):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(
|
||||
app.interface.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path))
|
||||
else:
|
||||
return FileResponse(safe_join(app.cwd, path))
|
||||
|
||||
|
||||
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
|
||||
@app.get("/api/", response_class=HTMLResponse)
|
||||
def api_docs(request: Request):
|
||||
@ -244,9 +260,8 @@ async def interpret(request: Request):
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(request: Request):
|
||||
body = await request.json()
|
||||
data = body["data"]
|
||||
action = body["action"]
|
||||
job_hash, queue_position = queueing.push({"data": data}, action)
|
||||
job_hash, queue_position = queueing.push(body, action)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict
|
||||
|
||||
from gradio import processing_utils
|
||||
|
||||
@ -32,7 +33,13 @@ class Component:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
def save_flagged(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool
|
||||
) -> Any:
|
||||
"""
|
||||
Saves flagged data from component
|
||||
"""
|
||||
@ -44,7 +51,16 @@ class Component:
|
||||
"""
|
||||
return data
|
||||
|
||||
def save_flagged_file(self, dir, label, data, encryption_key):
|
||||
def save_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool
|
||||
) -> str:
|
||||
"""
|
||||
Saved flagged data (e.g. image or audio) as a file and returns filepath
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
file = processing_utils.decode_base64_to_file(data, encryption_key)
|
||||
@ -64,7 +80,15 @@ class Component:
|
||||
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
|
||||
return label + "/" + new_file_name
|
||||
|
||||
def restore_flagged_file(self, dir, file, encryption_key):
|
||||
def restore_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
file: str,
|
||||
encryption_key: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Loads flagged data from file and returns it
|
||||
"""
|
||||
data = processing_utils.encode_file_to_base64(
|
||||
os.path.join(dir, file), encryption_key=encryption_key
|
||||
)
|
||||
|
@ -3,12 +3,19 @@ from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHA256
|
||||
|
||||
|
||||
def get_key(password):
|
||||
def get_key(
|
||||
password: str
|
||||
) -> bytes:
|
||||
"""Generates an encryption key based on the password provided."""
|
||||
key = SHA256.new(password.encode()).digest()
|
||||
return key
|
||||
|
||||
|
||||
def encrypt(key, source):
|
||||
def encrypt(
|
||||
key: bytes,
|
||||
source: bytes
|
||||
) -> bytes:
|
||||
"""Encrypts source data using the provided encryption key"""
|
||||
IV = Random.new().read(AES.block_size) # generate IV
|
||||
encryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
padding = AES.block_size - len(source) % AES.block_size # calculate needed padding
|
||||
@ -17,7 +24,10 @@ def encrypt(key, source):
|
||||
return data
|
||||
|
||||
|
||||
def decrypt(key, source):
|
||||
def decrypt(
|
||||
key: bytes,
|
||||
source: bytes
|
||||
) -> bytes:
|
||||
IV = source[: AES.block_size] # extract the IV from the beginning
|
||||
decryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
data = decryptor.decrypt(source[AES.block_size :]) # decrypt
|
||||
|
@ -109,9 +109,8 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
class CSVLogger(FlaggingCallback):
|
||||
"""
|
||||
The default implementation of the FlaggingCallback abstract class.
|
||||
Logs the input and output data to a CSV file.
|
||||
Logs the input and output data to a CSV file. Supports encryption.
|
||||
"""
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
@ -244,8 +244,8 @@ class Interface:
|
||||
self.description = description
|
||||
if article is not None:
|
||||
article = utils.readme_to_html(article)
|
||||
article = markdown2.markdown(article, extras=["fenced-code-blocks"])
|
||||
|
||||
article = markdown2.markdown(
|
||||
article, extras=["fenced-code-blocks"])
|
||||
self.article = article
|
||||
self.thumbnail = thumbnail
|
||||
|
||||
|
@ -1 +1 @@
|
||||
2.7.5
|
||||
2.7.5.2
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ except ImportError:
|
||||
|
||||
setup(
|
||||
name="gradio",
|
||||
version="2.7.5",
|
||||
version="2.7.5.2",
|
||||
include_package_data=True,
|
||||
description="Python library for easily interacting with trained machine learning models",
|
||||
author="Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq",
|
||||
|
33
test/test_encryptor.py
Normal file
33
test/test_encryptor.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from gradio import encryptor, processing_utils
|
||||
from gradio.test_data import BASE64_IMAGE
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestKeyGenerator(unittest.TestCase):
|
||||
def test_same_pass(self):
|
||||
key1 = encryptor.get_key("test")
|
||||
key2 = encryptor.get_key("test")
|
||||
self.assertEquals(key1, key2)
|
||||
|
||||
def test_diff_pass(self):
|
||||
key1 = encryptor.get_key("test")
|
||||
key2 = encryptor.get_key("diff_test")
|
||||
self.assertNotEquals(key1, key2)
|
||||
|
||||
|
||||
class TestEncryptorDecryptor(unittest.TestCase):
|
||||
def test_same_pass(self):
|
||||
key = encryptor.get_key("test")
|
||||
data, _ = processing_utils.decode_base64_to_binary(BASE64_IMAGE)
|
||||
encrypted_data = encryptor.encrypt(key, data)
|
||||
decrypted_data = encryptor.decrypt(key, encrypted_data)
|
||||
self.assertEquals(data, decrypted_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,12 +1,11 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
class TestDefaultFlagging(unittest.TestCase):
|
||||
def test_default_flagging_handler(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
@ -17,6 +16,8 @@ class TestFlagging(unittest.TestCase):
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
|
||||
class TestSimpleFlagging(unittest.TestCase):
|
||||
def test_simple_csv_flagging_handler(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
@ -33,6 +34,6 @@ class TestFlagging(unittest.TestCase):
|
||||
self.assertEqual(row_count, 1) # no header
|
||||
io.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user