from flask import render_template, redirect
from flask_appbuilder.models.sqla.interface import SQLAInterface
from myapp.models.model_train_model import Training_Model
from myapp.models.model_serving import InferenceService
from flask_babel import lazy_gettext as _
from myapp import app, appbuilder, db, event_logger
import logging
import re
import uuid
from myapp.views.view_team import Project_Filter, Project_Join_Filter, filter_join_org_project
from wtforms.validators import DataRequired, Length, NumberRange, Optional, Regexp
from wtforms import BooleanField, IntegerField, SelectField, StringField
from flask_appbuilder.fieldwidgets import Select2Widget
from myapp.forms import MyBS3TextFieldWidget, MySelectMultipleField
from flask import (
    current_app,
    abort,
    flash,
    g,
    Markup,
    make_response,
    redirect
)
from .base import (
    DeleteMixin,
    MyappFilter,
    MyappModelView,
    json_response
)
from .baseApi import (
    MyappModelRestApi
)
from flask_appbuilder import CompactCRUDMixin, expose
import pysnooper
import datetime
import time
import json

conf = app.config


class Training_Model_Filter(MyappFilter):
    """Limit the visible Training_Model rows to those created by the current user.

    Users holding a role whose lower-cased name is "admin" see every row.
    """

    def apply(self, query, func):
        user_roles = [role.name.lower() for role in list(self.get_user_roles())]
        if "admin" in user_roles:
            return query
        return query.filter(self.model.created_by_fk == g.user.id)


class Training_Model_ModelView_Base():
    """Configuration and handlers shared by the Training_Model page view and REST API."""

    datamodel = SQLAInterface(Training_Model)
    base_permissions = ['can_add', 'can_edit', 'can_delete', 'can_list', 'can_show']
    base_order = ('changed_on', 'desc')
    order_columns = ['id']
    list_columns = ['project_url', 'name', 'version', 'framework', 'api_type', 'pipeline_url', 'creator', 'modified', 'deploy']
    search_columns = ['created_by', 'project', 'name', 'version', 'framework', 'api_type', 'pipeline_id', 'run_id', 'path']
    add_columns = ['project', 'name', 'version', 'describe', 'path', 'framework', 'run_id', 'run_time', 'metrics', 'md5', 'api_type', 'pipeline_id']
    edit_columns = add_columns
    # The "project" relation dropdown is restricted by Project_Join_Filter with argument 'org'.
    add_form_query_rel_fields = {
        "project": [["name", Project_Join_Filter, 'org']]
    }
    edit_form_query_rel_fields = add_form_query_rel_fields
    # Per-column display hints (ellipsis style and pixel width) for the list page.
    cols_width = {
        "name": {"type": "ellip2", "width": 250},
        "project_url": {"type": "ellip2", "width": 200},
        "pipeline_url": {"type": "ellip2", "width": 300},
        "version": {"type": "ellip2", "width": 200},
        "modified": {"type": "ellip2", "width": 150},
        "deploy": {"type": "ellip2", "width": 100},
    }
    spec_label_columns = {
        "path": "模型文件",
        "framework": "算法框架",
        "api_type": "推理框架",
        "pipeline_id": "任务流id",
        "deploy": "发布"
    }

    label_title = '模型'
    base_filters = [["id", Training_Model_Filter, lambda: []]]

    # Help text shown under the model path field: expected artifact format per serving framework.
    path_describe = r''' tfserving:仅支持tf save_model方式的模型目录, /mnt/xx/../saved_model/
torch-server:torch-model-archiver编译后的mar模型文件地址, /mnt/xx/../xx.mar或torch script保存的模型
onnxruntime:onnx模型文件的地址, /mnt/xx/../xx.onnx
tensorrt:模型文件地址, /mnt/xx/../xx.plan
'''

    service_type_choices = [x.replace('_', '-') for x in ['tfserving', 'torch-server', 'onnxruntime', 'triton-server']]

    # NOTE(review): the `version`, `run_id` and `run_time` defaults below are evaluated once,
    # when this module is imported, so every form shows the process start time and the same
    # run_id until the server restarts. WTForms accepts callable defaults, but confirm that the
    # REST API's form serialization can handle callables before switching.
    add_form_extra_fields = {
        "path": StringField(
            _('模型文件地址'),
            default='/mnt/admin/xx/saved_model/',
            description=_(path_describe),
            validators=[DataRequired()]
        ),
        "describe": StringField(
            _(datamodel.obj.lab('describe')),
            description=_('模型描述'),
            validators=[DataRequired()]
        ),
        "pipeline_id": StringField(
            _(datamodel.obj.lab('pipeline_id')),
            description=_('任务流的id,0表示非任务流产生模型'),
            default='0'
        ),
        "version": StringField(
            _('版本'),
            widget=MyBS3TextFieldWidget(),
            description='模型版本',
            default=datetime.datetime.now().strftime('v%Y.%m.%d.1'),
            validators=[DataRequired()]
        ),
        "run_id": StringField(
            _(datamodel.obj.lab('run_id')),
            widget=MyBS3TextFieldWidget(),
            description='pipeline 训练的run id',
            default='random_run_id_' + uuid.uuid4().hex[:32]
        ),
        "run_time": StringField(
            _(datamodel.obj.lab('run_time')),
            widget=MyBS3TextFieldWidget(),
            description='pipeline 训练的 运行时间',
            default=datetime.datetime.now().strftime('%Y.%m.%d %H:%M:%S'),
        ),
        "name": StringField(
            _("模型名"),
            widget=MyBS3TextFieldWidget(),
            description='模型名(a-z0-9-字符组成,最长54个字符)',
            # Raw string: same pattern value as before, without the invalid-escape warning.
            validators=[DataRequired(), Regexp(r"^[a-z0-9\-]*$"), Length(1, 54)]
        ),
        "framework": SelectField(
            _('算法框架'),
            description="选项xgb、tf、pytorch、onnx、tensorrt等",
            widget=Select2Widget(),
            choices=[['xgb', 'xgb'], ['tf', 'tf'], ['pytorch', 'pytorch'], ['onnx', 'onnx'], ['tensorrt', 'tensorrt']],
            validators=[DataRequired()]
        ),
        'api_type': SelectField(
            _("部署类型"),
            description="推理框架类型",
            choices=[[x, x] for x in service_type_choices],
            validators=[DataRequired()]
        )
    }
    edit_form_extra_fields = add_form_extra_fields

    def pre_add(self, item):
        """Assign a random run_id when the submitted item has none."""
        if not item.run_id:
            item.run_id = 'random_run_id_' + uuid.uuid4().hex[:32]

    def pre_update(self, item):
        """Keep the previously stored path when an edit submits an empty one, then apply pre_add."""
        if not item.path:
            item.path = self.src_item_json['path']
        self.pre_add(item)

    # The URL must declare <model_id>: Flask passes URL variables as keyword arguments,
    # and without it every request failed with a missing-argument TypeError.
    @expose("/deploy/<model_id>", methods=["GET", 'POST'])
    def deploy(self, model_id):
        """Ensure an InferenceService exists for this model/version, then redirect to its list.

        A new service is created (named "<service_type>-<model name>-<version without 'v'
        and '.'>") only when no service with the same model_name and model_version exists.

        :param model_id: primary key of the Training_Model row, taken from the URL.
        :returns: a redirect to the MODEL_URLS['inferenceservice'] page filtered by model name.
        :raises: HTTP 404 (via ``abort``) when no Training_Model has this id.
        """
        train_model = db.session.query(Training_Model).filter_by(id=model_id).first()
        if train_model is None:
            # Previously this crashed with AttributeError on train_model.name.
            abort(404)
        exist_inference = (
            db.session.query(InferenceService)
            .filter_by(model_name=train_model.name)
            .filter_by(model_version=train_model.version)
            .first()
        )
        # Imported locally, presumably to avoid a circular import between view modules.
        from myapp.views.view_inferenceserving import InferenceService_ModelView_base
        inference_class = InferenceService_ModelView_base()
        inference_class.src_item_json = {}
        if not exist_inference:
            exist_inference = InferenceService()
            exist_inference.project_id = train_model.project_id
            exist_inference.project = train_model.project
            exist_inference.model_name = train_model.name
            exist_inference.label = train_model.describe
            exist_inference.model_version = train_model.version
            exist_inference.model_path = train_model.path
            exist_inference.service_type = train_model.api_type
            exist_inference.images = ''
            exist_inference.name = '%s-%s-%s' % (
                exist_inference.service_type,
                train_model.name,
                train_model.version.replace('v', '').replace('.', ''),
            )
            # Let the inference-service view fill in its own derived fields before saving.
            inference_class.pre_add(exist_inference)

            db.session.add(exist_inference)
            db.session.commit()
            flash('新服务版本创建完成', 'success')
        else:
            flash('服务版本已存在', 'success')
        import urllib.parse

        url = conf.get('MODEL_URLS', {}).get('inferenceservice', '') + '?filter=' + urllib.parse.quote(
            json.dumps([{"key": "model_name", "value": exist_inference.model_name}], ensure_ascii=False)
        )
        logging.info('deploy redirect: %s', url)
        return redirect(url)


class Training_Model_ModelView(Training_Model_ModelView_Base, MyappModelView, DeleteMixin):
    datamodel = SQLAInterface(Training_Model)


appbuilder.add_view(Training_Model_ModelView, "模型管理", icon='fa-hdd-o', category='服务化', category_icon='fa-tasks')


class Training_Model_ModelView_Api(Training_Model_ModelView_Base, MyappModelRestApi):  # noqa
    datamodel = SQLAInterface(Training_Model)
    route_base = '/training_model_modelview/api'


appbuilder.add_api(Training_Model_ModelView_Api)