# cube-studio/myapp/views/view_train_model.py

from flask import render_template,redirect
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask import Blueprint, current_app, jsonify, make_response, request
from myapp.models.model_serving import Service
from myapp.models.model_train_model import Training_Model
from myapp.models.model_serving import InferenceService
from myapp.models.model_team import Project,Project_User
from myapp.utils import core
from flask_babel import gettext as __
from flask_babel import lazy_gettext as _
from flask_appbuilder.actions import action
from myapp import app, appbuilder,db,event_logger
import logging
import re
import uuid
from myapp.views.view_team import Project_Filter,Project_Join_Filter,filter_join_org_project
import requests
from myapp.exceptions import MyappException
from flask_appbuilder.security.decorators import has_access
from myapp.models.model_job import Repository,Pipeline
from myapp.project import push_message,push_admin
from flask_wtf.file import FileAllowed, FileField, FileRequired
from werkzeug.datastructures import FileStorage
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from myapp import security_manager
import os,sys
from wtforms.validators import DataRequired, Length, NumberRange, Optional,Regexp
from wtforms import BooleanField, IntegerField, SelectField, StringField,FloatField,DateField,DateTimeField,SelectMultipleField,FormField,FieldList
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget,BS3PasswordFieldWidget,DatePickerWidget,DateTimePickerWidget,Select2ManyWidget,Select2Widget
from myapp.forms import MyBS3TextAreaFieldWidget,MySelect2Widget,MyCodeArea,MyLineSeparatedListField,MyJSONField,MyBS3TextFieldWidget,MySelectMultipleField
from myapp.utils.py import py_k8s
import os, zipfile
import shutil
from flask import (
current_app,
abort,
flash,
g,
Markup,
make_response,
redirect,
render_template,
request,
send_from_directory,
Response,
url_for,
)
from .base import (
DeleteMixin,
api,
BaseMyappView,
check_ownership,
data_payload_response,
DeleteMixin,
generate_download_headers,
get_error_msg,
get_user_roles,
handle_api_exception,
json_error_response,
json_success,
MyappFilter,
MyappModelView,
json_response
)
from sqlalchemy import and_, or_, select
from .baseApi import (
MyappModelRestApi
)
from flask_appbuilder import CompactCRUDMixin, expose
import pysnooper,datetime,time,json
# Flask application config (e.g. MODEL_URLS is read from it when deploying a model).
conf = app.config


class Training_Model_Filter(MyappFilter):
    """Row-level filter for Training_Model listings.

    Users holding a role named "admin" (case-insensitive) see every row;
    everyone else only sees rows whose ``created_by_fk`` is their own user id.
    """

    def apply(self, query, func):
        lowered_roles = (role.name.lower() for role in self.get_user_roles())
        if any(role_name == "admin" for role_name in lowered_roles):
            return query
        # Non-admins: restrict to models created by the current user.
        return query.filter(self.model.created_by_fk == g.user.id)
class Training_Model_ModelView_Base:
    """Shared configuration and behaviour for the training-model views.

    Mixed into both the server-rendered view and the REST API view. It defines
    the list/search/add/edit columns, the add/edit form fields, the pre-add /
    pre-update hooks, and a ``/deploy/<model_id>`` endpoint that turns a
    registered model into an ``InferenceService`` record.
    """

    datamodel = SQLAInterface(Training_Model)

    base_permissions = ['can_add', 'can_edit', 'can_delete', 'can_list', 'can_show']
    base_order = ('changed_on', 'desc')
    order_columns = ['id']
    list_columns = ['project_url', 'name', 'version', 'framework', 'api_type', 'pipeline_url', 'creator', 'modified', 'deploy']
    search_columns = ['created_by', 'project', 'name', 'version', 'framework', 'api_type', 'pipeline_id', 'run_id', 'path']
    add_columns = ['project', 'name', 'version', 'describe', 'path', 'framework', 'run_id', 'run_time', 'metrics', 'md5', 'api_type', 'pipeline_id']
    edit_columns = add_columns
    # Restricts the "project" dropdown via Project_Join_Filter with value 'org'
    # (presumably: org-type projects the user has joined — verify Project_Join_Filter).
    add_form_query_rel_fields = {
        "project": [["name", Project_Join_Filter, 'org']]
    }
    edit_form_query_rel_fields = add_form_query_rel_fields
    cols_width = {
        "name": {"type": "ellip2", "width": 250},
        "project_url": {"type": "ellip2", "width": 200},
        "pipeline_url": {"type": "ellip2", "width": 300},
        "version": {"type": "ellip2", "width": 200},
        "modified": {"type": "ellip2", "width": 150},
        "deploy": {"type": "ellip2", "width": 100},
    }
    spec_label_columns = {
        "path": "模型文件",
        "framework": "算法框架",
        "api_type": "推理框架",
        "pipeline_id": "任务流id",
        "deploy": "发布"
    }
    label_title = '模型'
    # Non-admin users only see their own models (see Training_Model_Filter.apply).
    base_filters = [["id", Training_Model_Filter, lambda: []]]

    # HTML help text for the model-path field, one line per serving framework.
    path_describe = r'''
        tfserving：仅支持tf save_model方式的模型目录, /mnt/xx/../saved_model/<br>
        torch-server：torch-model-archiver编译后的mar模型文件地址, /mnt/xx/../xx.mar或torch script保存的模型<br>
        onnxruntime：onnx模型文件的地址, /mnt/xx/../xx.onnx<br>
        tensorrt:模型文件地址, /mnt/xx/../xx.plan<br>
        '''
    service_type_choices = [x.replace('_', '-') for x in ['tfserving', 'torch-server', 'onnxruntime', 'triton-server']]

    # NOTE(review): the defaults for "version", "run_id" and "run_time" below are
    # evaluated once, when this module is imported, so every form shows the
    # server start date and the same "random" run_id until restart. WTForms
    # accepts callable defaults, but confirm the REST API can serialize them
    # before switching.
    add_form_extra_fields = {
        "path": StringField(
            _('模型文件地址'),
            default='/mnt/admin/xx/saved_model/',
            description=_(path_describe),
            validators=[DataRequired()]
        ),
        "describe": StringField(
            _(datamodel.obj.lab('describe')),
            description=_('模型描述'),
            validators=[DataRequired()]
        ),
        "pipeline_id": StringField(
            _(datamodel.obj.lab('pipeline_id')),
            description=_('任务流的id，0表示非任务流产生模型'),
            default='0'
        ),
        "version": StringField(
            _('版本'),
            widget=MyBS3TextFieldWidget(),
            description='模型版本',
            default=datetime.datetime.now().strftime('v%Y.%m.%d.1'),
            validators=[DataRequired()]
        ),
        "run_id": StringField(
            _(datamodel.obj.lab('run_id')),
            widget=MyBS3TextFieldWidget(),
            description='pipeline 训练的run id',
            default='random_run_id_' + uuid.uuid4().hex[:32]
        ),
        "run_time": StringField(
            _(datamodel.obj.lab('run_time')),
            widget=MyBS3TextFieldWidget(),
            description='pipeline 训练的 运行时间',
            default=datetime.datetime.now().strftime('%Y.%m.%d %H:%M:%S'),
        ),
        "name": StringField(
            _("模型名"),
            widget=MyBS3TextFieldWidget(),
            description='模型名(a-z0-9-字符组成，最长54个字符)',
            # Raw string: "\-" in a plain literal is an invalid escape sequence.
            validators=[DataRequired(), Regexp(r"^[a-z0-9\-]*$"), Length(1, 54)]
        ),
        "framework": SelectField(
            _('算法框架'),
            description="选项xgb、tf、pytorch、onnx、tensorrt等",
            widget=Select2Widget(),
            choices=[['xgb', 'xgb'], ['tf', 'tf'], ['pytorch', 'pytorch'], ['onnx', 'onnx'], ['tensorrt', 'tensorrt']],
            validators=[DataRequired()]
        ),
        'api_type': SelectField(
            _("部署类型"),
            description="推理框架类型",
            choices=[[x, x] for x in service_type_choices],
            validators=[DataRequired()]
        )
    }
    # NOTE: the edit form shares the very same dict (and field objects) as the
    # add form; copy it before customizing any field for one form only.
    edit_form_extra_fields = add_form_extra_fields

    def pre_add(self, item):
        """Give the new model a random run_id when the form left it empty."""
        if not item.run_id:
            item.run_id = 'random_run_id_' + uuid.uuid4().hex[:32]

    def pre_update(self, item):
        """Keep the stored path when none was submitted, then apply pre_add's run_id fallback.

        ``self.src_item_json`` is presumably the pre-edit item serialized by the
        base view — verify against MyappModelView / MyappModelRestApi.
        """
        if not item.path:
            item.path = self.src_item_json['path']
        self.pre_add(item)

    @expose("/deploy/<model_id>", methods=["GET", 'POST'])
    def deploy(self, model_id):
        """Create (or reuse) an InferenceService for a model, then redirect to its list.

        Looks for an ``InferenceService`` with the same model name and version.
        If none exists, one is built from the model's fields, passed through the
        inference view's ``pre_add`` hook, and committed. A flash message reports
        which case happened. The response redirects to the URL configured under
        ``MODEL_URLS['inferenceservice']``, filtered by model name.

        Aborts with HTTP 404 when ``model_id`` matches no Training_Model.
        """
        train_model = db.session.query(Training_Model).filter_by(id=model_id).first()
        if train_model is None:
            # Previously this fell through to train_model.name -> AttributeError (HTTP 500).
            abort(404)
        exist_inference = db.session.query(InferenceService).filter_by(model_name=train_model.name).filter_by(model_version=train_model.version).first()
        # Imported lazily, presumably to avoid a circular import between view modules.
        from myapp.views.view_inferenceserving import InferenceService_ModelView_base
        inference_class = InferenceService_ModelView_base()
        inference_class.src_item_json = {}
        if not exist_inference:
            exist_inference = InferenceService()
            exist_inference.project_id = train_model.project_id
            exist_inference.project = train_model.project
            exist_inference.model_name = train_model.name
            exist_inference.label = train_model.describe
            exist_inference.model_version = train_model.version
            exist_inference.model_path = train_model.path
            exist_inference.service_type = train_model.api_type
            exist_inference.images = ''
            # e.g. service_type "tfserving", name "demo", version "v2022.08.08.1" -> "tfserving-demo-202208081"
            exist_inference.name = '%s-%s-%s' % (exist_inference.service_type, train_model.name, train_model.version.replace('v', '').replace('.', ''))
            inference_class.pre_add(exist_inference)
            db.session.add(exist_inference)
            db.session.commit()
            flash('新服务版本创建完成', 'success')
        else:
            flash('服务版本已存在', 'success')
        import urllib.parse
        url = conf.get('MODEL_URLS', {}).get('inferenceservice', '') + '?filter=' + urllib.parse.quote(json.dumps([{"key": "model_name", "value": exist_inference.model_name}], ensure_ascii=False))
        logging.info('deploy redirect url: %s', url)
        return redirect(url)
class Training_Model_ModelView(Training_Model_ModelView_Base, MyappModelView, DeleteMixin):
    """Server-rendered CRUD view for Training_Model; configuration comes from the base mixin."""

    datamodel = SQLAInterface(Training_Model)


# Register the view as "模型管理" (model management) under the "服务化" (serving) menu category.
appbuilder.add_view(
    Training_Model_ModelView,
    "模型管理",
    icon='fa-hdd-o',
    category='服务化',
    category_icon='fa-tasks',
)
class Training_Model_ModelView_Api(Training_Model_ModelView_Base, MyappModelRestApi):  # noqa
    """REST API counterpart of the training-model view, served under /training_model_modelview/api."""

    route_base = '/training_model_modelview/api'
    datamodel = SQLAInterface(Training_Model)


appbuilder.add_api(Training_Model_ModelView_Api)