mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
flag now outputs rebuilt file in source folder
This commit is contained in:
parent
16097af9f9
commit
ed789d1e4e
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,3 +10,4 @@ models/*
|
||||
gradio_files/*
|
||||
ngrok*
|
||||
examples/ngrok*
|
||||
gradio-flagged/*
|
@ -87,7 +87,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -97,23 +97,22 @@
|
||||
"io = gradio.Interface(inputs=inp, \n",
|
||||
" outputs=out,\n",
|
||||
" model=model, \n",
|
||||
" model_type='keras')"
|
||||
" model_type='keras')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Closing existing server...\n",
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
|
||||
"Model is running locally at: http://localhost:7863/\n",
|
||||
"Model available publicly for 8 hours at: https://bbf127b1.gradio.app/\n"
|
||||
"Model is running locally at: http://localhost:7860/\n",
|
||||
"Unable to create public link for interface, please check internet connection or try restarting python interpreter.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -123,14 +122,14 @@
|
||||
" <iframe\n",
|
||||
" width=\"1000\"\n",
|
||||
" height=\"500\"\n",
|
||||
" src=\"http://localhost:7863/\"\n",
|
||||
" src=\"http://localhost:7860/\"\n",
|
||||
" frameborder=\"0\"\n",
|
||||
" allowfullscreen\n",
|
||||
" ></iframe>\n",
|
||||
" "
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.lib.display.IFrame at 0x1fda5796f98>"
|
||||
"<IPython.lib.display.IFrame at 0x19e9d6b37b8>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
File diff suppressed because one or more lines are too long
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from gradio import preprocessing_utils, validation_data
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import datetime
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
|
||||
@ -63,6 +64,12 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
pass
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1, shift=0,
|
||||
@ -95,7 +102,15 @@ class Sketchpad(AbstractInput):
|
||||
array = array * self.scale + self.shift
|
||||
array = array.astype(self.dtype)
|
||||
return array
|
||||
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
im.save(f'gradio-flagged/{timestamp.strftime("%Y-%m-%d %H-%M-%S")}.png', 'PNG')
|
||||
return None
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
@ -119,6 +134,14 @@ class Webcam(AbstractInput):
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
im.save(f'gradio-flagged/{timestamp.strftime("%Y-%m-%d %H-%M-%S")}.png', 'PNG')
|
||||
return None
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
@ -133,7 +156,15 @@ class Textbox(AbstractInput):
|
||||
By default, no pre-processing is applied to text.
|
||||
"""
|
||||
return inp
|
||||
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
Default rebuild method for text saves it .txt file
|
||||
"""
|
||||
timestamp = datetime.datetime.now()
|
||||
f = open(f'gradio-flagged/{timestamp.strftime("%Y-%m-%d %H-%M-%S")}.txt','w')
|
||||
f.write(inp)
|
||||
f.close()
|
||||
return None
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
@ -171,6 +202,14 @@ class ImageUpload(AbstractInput):
|
||||
array = im.reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
im.save(f'gradio-flagged/{timestamp.strftime("%Y-%m-%d %H-%M-%S")}.png', 'PNG')
|
||||
return None
|
||||
|
||||
class CSV(AbstractInput):
|
||||
|
||||
|
@ -9,12 +9,15 @@ import nest_asyncio
|
||||
import webbrowser
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking, strings
|
||||
from gradio import networking, strings, inputs
|
||||
import tempfile
|
||||
import threading
|
||||
import traceback
|
||||
import urllib
|
||||
import json
|
||||
import os
|
||||
import errno
|
||||
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
@ -123,10 +126,17 @@ class Interface:
|
||||
}
|
||||
await websocket.send(json.dumps(output))
|
||||
if msg['action'] == 'flag':
|
||||
print('flagged')
|
||||
f = open('gradio-flagged.txt','a+')
|
||||
if not os.path.exists(os.path.dirname('gradio-flagged/')):
|
||||
try:
|
||||
os.makedirs(os.path.dirname('gradio-flagged/'))
|
||||
except OSError as exc: # Guard against race condition
|
||||
if exc.errno != errno.EEXIST:
|
||||
raise
|
||||
f = open('gradio-flagged/gradio-flagged.txt','a+')
|
||||
f.write(str(msg['data']))
|
||||
f.close()
|
||||
inp = msg['data']['input']
|
||||
self.input_interface.rebuild_flagged(inp)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user