mirror of
https://github.com/tencentmusic/cube-studio.git
synced 2025-02-17 14:40:28 +08:00
修正数据集备份功能
This commit is contained in:
parent
5d32d32612
commit
636b3fc5d4
@ -277,24 +277,48 @@ def update_dataset(task,dataset_id):
|
||||
with session_scope(nullpool=True) as dbsession:
|
||||
try:
|
||||
dataset = dbsession.query(Dataset).filter_by(id=dataset_id).first()
|
||||
store_type = conf.get('STORE_TYPE', 'minio')
|
||||
params = importlib.import_module(f'myapp.utils.store.{store_type}')
|
||||
store_client = getattr(params, store_type.upper() + '_client')(**conf.get('STORE_CONFIG', {}))
|
||||
|
||||
remote_dir = f'dataset/{dataset.name}/{dataset.version if dataset.version else "latest"}/'
|
||||
remote_dir = os.path.join('/data/k8s/kubeflow/global/', remote_dir)
|
||||
if os.path.exists(remote_dir):
|
||||
# 先清理干净,因为有可能存在旧的不对的数据
|
||||
import shutil
|
||||
shutil.rmtree(remote_dir, ignore_errors=True)
|
||||
os.makedirs(remote_dir, exist_ok=True)
|
||||
|
||||
# 备份在本地
|
||||
if dataset.path:
|
||||
paths = dataset.path.split("\n")
|
||||
for path in paths:
|
||||
file_name = path[path.rindex("/") + 1:]
|
||||
local_path = os.path.join('/home/myapp/myapp/static/', path.lstrip('/'))
|
||||
store_client.uploadfile(local_path,remote_file_path=f'/dataset/{dataset.name}/{dataset.version if dataset.version else "latest"}/{dataset.subdataset}/{dataset.segment if dataset.segment else "0"}/{file_name}')
|
||||
if os.path.exists(local_path):
|
||||
# 对文件直接复制
|
||||
if os.path.isfile(local_path):
|
||||
shutil.copy(local_path,remote_dir)
|
||||
# 对文件夹要拷贝文件夹
|
||||
if os.path.isdir(local_path):
|
||||
shutil.copytree(local_path,remote_dir)
|
||||
|
||||
elif dataset.download_url:
|
||||
download_urls = dataset.download_url.split("\n")
|
||||
for download_url in download_urls:
|
||||
file_name = download_url[download_url.rindex("/") + 1:]
|
||||
store_client.uploadfile(download_url,remote_file_path=f'/dataset/{dataset.name}/{dataset.version if dataset.version else "latest"}/{dataset.subdataset}/{dataset.segment if dataset.segment else "0"}/{file_name}')
|
||||
try:
|
||||
import requests
|
||||
filename = download_url.split("/")[-1]
|
||||
try_num=0
|
||||
while try_num<3:
|
||||
try_num+=1
|
||||
response = requests.get(download_url)
|
||||
with open(remote_dir + '/' + filename, 'wb') as f:
|
||||
f.write(response.content)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
push_admin(f'数据集备份失败,id:{dataset_id}')
|
||||
|
||||
|
||||
if __name__ =='__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user