from flask import render_template,redirect from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder import ModelView, ModelRestApi from flask_appbuilder import ModelView,AppBuilder,expose,BaseView,has_access from importlib import reload from flask_babel import gettext as __ from flask_babel import lazy_gettext as _ # 将model添加成视图,并控制在前端的显示 import uuid from myapp.models.model_katib import Hyperparameter_Tuning from myapp.models.model_job import Repository from flask_appbuilder.actions import action from flask_appbuilder.models.sqla.filters import FilterEqualFunction, FilterStartsWith,FilterEqual,FilterNotEqual from wtforms.validators import EqualTo,Length from flask_babel import lazy_gettext,gettext from flask_appbuilder.security.decorators import has_access from flask_appbuilder.forms import GeneralModelConverter from myapp.utils import core from myapp import app, appbuilder,db,event_logger from wtforms.ext.sqlalchemy.fields import QuerySelectField import os,sys from wtforms.validators import DataRequired, Length, NumberRange, Optional,Regexp from sqlalchemy import and_, or_, select from myapp.exceptions import MyappException from wtforms import BooleanField, IntegerField, SelectField, StringField,FloatField,DateField,DateTimeField,SelectMultipleField,FormField,FieldList from flask_appbuilder.fieldwidgets import BS3TextFieldWidget,BS3PasswordFieldWidget,DatePickerWidget,DateTimePickerWidget,Select2ManyWidget,Select2Widget from myapp.forms import MyBS3TextAreaFieldWidget,MySelect2Widget,MyCodeArea,MyLineSeparatedListField,MyJSONField,MyBS3TextFieldWidget,MyCommaSeparatedListField,MySelectMultipleField from myapp.views.view_team import Project_Filter from myapp.utils.py import py_k8s from flask_wtf.file import FileField import shlex import re,copy from flask import ( current_app, abort, flash, g, Markup, make_response, redirect, render_template, request, send_from_directory, Response, url_for, ) from .baseApi import ( MyappModelRestApi ) from myapp import security_manager from werkzeug.datastructures import FileStorage from .base import ( api, BaseMyappView, check_ownership, data_payload_response, DeleteMixin, generate_download_headers, get_error_msg, get_user_roles, handle_api_exception, json_error_response, json_success, MyappFilter, MyappModelView, ) from flask_appbuilder import CompactCRUDMixin, expose import pysnooper,datetime,time,json from kubernetes.client import V1ObjectMeta import kubeflow.katib as kc from kubeflow.katib import constants from kubeflow.katib import utils from kubeflow.katib import V1alpha3AlgorithmSetting from kubeflow.katib import V1alpha3AlgorithmSetting from kubeflow.katib import V1alpha3AlgorithmSpec from kubeflow.katib import V1alpha3CollectorSpec from kubeflow.katib import V1alpha3EarlyStoppingSetting from kubeflow.katib import V1alpha3EarlyStoppingSpec from kubeflow.katib import V1alpha3Experiment from kubeflow.katib import V1alpha3ExperimentCondition from kubeflow.katib import V1alpha3ExperimentList from kubeflow.katib import V1alpha3ExperimentSpec from kubeflow.katib import V1alpha3ExperimentStatus from kubeflow.katib import V1alpha3FeasibleSpace from kubeflow.katib import V1alpha3FileSystemPath from kubeflow.katib import V1alpha3FilterSpec from kubeflow.katib import V1alpha3GoTemplate from kubeflow.katib import V1alpha3GraphConfig from kubeflow.katib import V1alpha3Metric from kubeflow.katib import V1alpha3MetricsCollectorSpec from kubeflow.katib import V1alpha3NasConfig from kubeflow.katib import V1alpha3ObjectiveSpec from kubeflow.katib import V1alpha3Observation from kubeflow.katib import V1alpha3Operation from kubeflow.katib import V1alpha3OptimalTrial from kubeflow.katib import V1alpha3ParameterAssignment from kubeflow.katib import V1alpha3ParameterSpec from kubeflow.katib import V1alpha3SourceSpec from kubeflow.katib import V1alpha3Suggestion from kubeflow.katib import V1alpha3SuggestionCondition from kubeflow.katib import V1alpha3SuggestionList from kubeflow.katib import V1alpha3SuggestionSpec from kubeflow.katib import V1alpha3SuggestionStatus from kubeflow.katib import V1alpha3TemplateSpec from kubeflow.katib import V1alpha3Trial from kubeflow.katib import V1alpha3TrialAssignment from kubeflow.katib import V1alpha3TrialCondition from kubeflow.katib import V1alpha3TrialList from kubeflow.katib import V1alpha3TrialSpec from kubeflow.katib import V1alpha3TrialStatus from kubeflow.katib import V1alpha3TrialTemplate conf = app.config class HP_Filter(MyappFilter): # @pysnooper.snoop() def apply(self, query, func): user_roles = [role.name.lower() for role in list(self.get_user_roles())] if "admin" in user_roles: return query.order_by(self.model.id.desc()) join_projects_id = security_manager.get_join_projects_id(db.session) # public_project_id = # logging.info(join_projects_id) return query.filter( or_( self.model.project_id.in_(join_projects_id), # self.model.project.name.in_(['public']) ) ).order_by(self.model.id.desc()) # 定义数据库视图 class Hyperparameter_Tuning_ModelView_Base(): datamodel = SQLAInterface(Hyperparameter_Tuning) conv = GeneralModelConverter(datamodel) label_title='超参搜索' check_redirect_list_url = '/hyperparameter_tuning_modelview/list/' help_url = conf.get('HELP_URL', {}).get(datamodel.obj.__tablename__, '') if datamodel else '' base_permissions = ['can_add', 'can_edit', 'can_delete', 'can_list', 'can_show'] # 默认为这些 base_order = ('id', 'desc') base_filters = [["id", HP_Filter, lambda: []]] # 设置权限过滤器 order_columns = ['id'] list_columns = ['project','name_url','describe','job_type','creator','run_url','modified'] show_columns = ['created_by','changed_by','created_on','changed_on','job_type','name','namespace','describe', 'parallel_trial_count','max_trial_count','max_failed_trial_count','objective_type', 'objective_goal','objective_metric_name','objective_additional_metric_names','algorithm_name', 'algorithm_setting','parameters_html','trial_spec_html','experiment_html'] add_form_query_rel_fields = { "project": [["name", Project_Filter, 'org']] } edit_form_query_rel_fields = add_form_query_rel_fields edit_form_extra_fields={} edit_form_extra_fields["alert_status"] = MySelectMultipleField( label=_(datamodel.obj.lab('alert_status')), widget=Select2ManyWidget(), choices=[[x, x] for x in ['Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Waiting', 'Terminated']], description="选择通知状态", ) edit_form_extra_fields['name'] = StringField( _(datamodel.obj.lab('name')), description='英文名(字母、数字、- 组成),最长50个字符', widget=BS3TextFieldWidget(), validators=[DataRequired(), Regexp("^[a-z][a-z0-9\-]*[a-z0-9]$"), Length(1, 54)] ) edit_form_extra_fields['describe'] = StringField( _(datamodel.obj.lab('describe')), description='中文描述', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['namespace'] = StringField( _(datamodel.obj.lab('namespace')), description='运行命名空间', widget=BS3TextFieldWidget(), default=datamodel.obj.namespace.default.arg, validators=[DataRequired()] ) edit_form_extra_fields['parallel_trial_count'] = IntegerField( _(datamodel.obj.lab('parallel_trial_count')), default=datamodel.obj.parallel_trial_count.default.arg, description='可并行的计算实例数目', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['max_trial_count'] = IntegerField( _(datamodel.obj.lab('max_trial_count')), default=datamodel.obj.max_trial_count.default.arg, description='最大并行的计算实例数目', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['max_failed_trial_count'] = IntegerField( _(datamodel.obj.lab('max_failed_trial_count')), default=datamodel.obj.max_failed_trial_count.default.arg, description='最大失败的计算实例数目', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['objective_type'] = SelectField( _(datamodel.obj.lab('objective_type')), default=datamodel.obj.objective_type.default.arg, description='目标函数类型(和自己代码中对应)', widget=Select2Widget(), choices=[['maximize', 'maximize'], ['minimize', 'minimize']], validators=[DataRequired()] ) edit_form_extra_fields['objective_goal'] = FloatField( _(datamodel.obj.lab('objective_goal')), default=datamodel.obj.objective_goal.default.arg, description='目标门限', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['objective_metric_name'] = StringField( _(datamodel.obj.lab('objective_metric_name')), default=datamodel.obj.objective_metric_name.default.arg, description='目标函数(和自己代码中对应)', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['objective_additional_metric_names'] = StringField( _(datamodel.obj.lab('objective_additional_metric_names')), default=datamodel.obj.objective_additional_metric_names.default.arg, description='其他目标函数(和自己代码中对应)', widget=BS3TextFieldWidget() ) algorithm_name_choices = ['grid', 'random', 'hyperband', 'bayesianoptimization'] algorithm_name_choices = [[algorithm_name_choice, algorithm_name_choice] for algorithm_name_choice in algorithm_name_choices] edit_form_extra_fields['algorithm_name'] = SelectField( _(datamodel.obj.lab('algorithm_name')), default=datamodel.obj.algorithm_name.default.arg, description='搜索算法', widget=Select2Widget(), choices=algorithm_name_choices, validators=[DataRequired()] ) edit_form_extra_fields['algorithm_setting'] = StringField( _(datamodel.obj.lab('algorithm_setting')), default=datamodel.obj.algorithm_setting.default.arg, widget=BS3TextFieldWidget(), description='搜索算法配置' ) edit_form_extra_fields['parameters_demo'] = StringField( _(datamodel.obj.lab('parameters_demo')), description='搜索参数示例,标准json格式,注意:所有整型、浮点型都写成字符串型', widget=MyCodeArea(code=core.hp_parameters_demo()), ) edit_form_extra_fields['parameters'] = StringField( _(datamodel.obj.lab('parameters')), default=datamodel.obj.parameters.default.arg, description='搜索参数,注意:所有整型、浮点型都写成字符串型', widget=MyBS3TextAreaFieldWidget(rows=10), validators=[DataRequired()] ) edit_form_extra_fields['node_selector'] = StringField( _(datamodel.obj.lab('node_selector')), description="部署task所在的机器(目前无需填写)", widget=BS3TextFieldWidget() ) edit_form_extra_fields['working_dir'] = StringField( _(datamodel.obj.lab('working_dir')), description="工作目录,如果为空,则使用Dockerfile中定义的workingdir", widget=BS3TextFieldWidget() ) edit_form_extra_fields['image_pull_policy'] = SelectField( _(datamodel.obj.lab('image_pull_policy')), description="镜像拉取策略(always为总是拉取远程镜像,IfNotPresent为若本地存在则使用本地镜像)", widget=Select2Widget(), choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']] ) edit_form_extra_fields['volume_mount'] = StringField( _(datamodel.obj.lab('volume_mount')), description='外部挂载,格式:$pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2', widget=BS3TextFieldWidget() ) edit_form_extra_fields['resource_memory'] = StringField( _(datamodel.obj.lab('resource_memory')), default=datamodel.obj.resource_memory.default.arg, description='内存的资源使用限制(每个测试实例),示例:1G,20G', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) edit_form_extra_fields['resource_cpu'] = StringField( _(datamodel.obj.lab('resource_cpu')), default=datamodel.obj.resource_cpu.default.arg, description='cpu的资源使用限制(每个测试实例)(单位:核),示例:2', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) # @pysnooper.snoop() def set_column(self, hp=None): # 对编辑进行处理 request_data = request.args.to_dict() job_type = request_data.get('job_type', '') if hp: job_type = hp.job_type job_type_choices = ['','TFJob','XGBoostJob','PyTorchJob','Job'] job_type_choices = [[job_type_choice,job_type_choice] for job_type_choice in job_type_choices] if hp: self.edit_form_extra_fields['job_type'] = SelectField( _(self.datamodel.obj.lab('job_type')), description="超参搜索的任务类型", choices=job_type_choices, widget=MySelect2Widget(extra_classes="readonly",value=job_type), validators=[DataRequired()] ) else: self.edit_form_extra_fields['job_type'] = SelectField( _(self.datamodel.obj.lab('job_type')), description="超参搜索的任务类型", widget=MySelect2Widget(new_web=True,value=job_type), choices=job_type_choices, validators=[DataRequired()] ) self.edit_form_extra_fields['tf_worker_num'] = IntegerField( _(self.datamodel.obj.lab('tf_worker_num')), default=json.loads(hp.job_json).get('tf_worker_num',3) if hp and hp.job_json else 3, description='工作节点数目', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['tf_worker_image'] = StringField( _(self.datamodel.obj.lab('tf_worker_image')), default=json.loads(hp.job_json).get('tf_worker_image',conf.get('KATIB_TFJOB_DEFAULT_IMAGE','')) if hp and hp.job_json else conf.get('KATIB_TFJOB_DEFAULT_IMAGE',''), description='工作节点镜像', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['tf_worker_command'] = StringField( _(self.datamodel.obj.lab('tf_worker_command')), default=json.loads(hp.job_json).get('tf_worker_command','python xx.py') if hp and hp.job_json else 'python xx.py', description='工作节点启动命令', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['job_worker_image'] = StringField( _(self.datamodel.obj.lab('job_worker_image')), default=json.loads(hp.job_json).get('job_worker_image',conf.get('KATIB_JOB_DEFAULT_IMAGE','')) if hp and hp.job_json else conf.get('KATIB_JOB_DEFAULT_IMAGE',''), description='工作节点镜像', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['job_worker_command'] = StringField( _(self.datamodel.obj.lab('job_worker_command')), default=json.loads(hp.job_json).get('job_worker_command','python xx.py') if hp and hp.job_json else 'python xx.py', description='工作节点启动命令', widget=MyBS3TextAreaFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['pytorch_worker_num'] = IntegerField( _(self.datamodel.obj.lab('pytorch_worker_num')), default=json.loads(hp.job_json).get('pytorch_worker_num', 3) if hp and hp.job_json else 3, description='工作节点数目', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['pytorch_worker_image'] = StringField( _(self.datamodel.obj.lab('pytorch_worker_image')), default=json.loads(hp.job_json).get('pytorch_worker_image',conf.get('KATIB_PYTORCHJOB_DEFAULT_IMAGE','')) if hp and hp.job_json else conf.get('KATIB_PYTORCHJOB_DEFAULT_IMAGE',''), description='工作节点镜像', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['pytorch_master_command'] = StringField( _(self.datamodel.obj.lab('pytorch_master_command')), default=json.loads(hp.job_json).get('pytorch_master_command', 'python xx.py') if hp and hp.job_json else 'python xx.py', description='master节点启动命令', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_form_extra_fields['pytorch_worker_command'] = StringField( _(self.datamodel.obj.lab('pytorch_worker_command')), default=json.loads(hp.job_json).get('pytorch_worker_command', 'python xx.py') if hp and hp.job_json else 'python xx.py', description='工作节点启动命令', widget=BS3TextFieldWidget(), validators=[DataRequired()] ) self.edit_columns = ['job_type','project','name','namespace','describe','parallel_trial_count','max_trial_count','max_failed_trial_count', 'objective_type','objective_goal','objective_metric_name','objective_additional_metric_names', 'algorithm_name','algorithm_setting','parameters_demo', 'parameters'] self.edit_fieldsets=[( lazy_gettext('common'), {"fields": copy.deepcopy(self.edit_columns), "expanded": True}, )] if job_type=='TFJob': group_columns = ['tf_worker_num','tf_worker_image','tf_worker_command'] self.edit_fieldsets.append(( lazy_gettext(job_type), {"fields":group_columns, "expanded": True}, ) ) for column in group_columns: self.edit_columns.append(column) if job_type=='Job': group_columns = ['job_worker_image','job_worker_command'] self.edit_fieldsets.append(( lazy_gettext(job_type), {"fields":group_columns, "expanded": True}, ) ) for column in group_columns: self.edit_columns.append(column) if job_type=='PyTorchJob': group_columns = ['pytorch_worker_num','pytorch_worker_image','pytorch_master_command','pytorch_worker_command'] self.edit_fieldsets.append(( lazy_gettext(job_type), {"fields":group_columns, "expanded": True}, ) ) for column in group_columns: self.edit_columns.append(column) if job_type=='XGBoostJob': group_columns = ['pytorchjob_worker_image','pytorchjob_worker_command'] self.edit_fieldsets.append(( lazy_gettext(job_type), {"fields":group_columns, "expanded": True}, ) ) for column in group_columns: self.edit_columns.append(column) task_column=['working_dir','volume_mount','node_selector','image_pull_policy','resource_memory','resource_cpu'] self.edit_fieldsets.append(( lazy_gettext('task args'), {"fields": task_column, "expanded": True}, )) for column in task_column: self.edit_columns.append(column) self.edit_fieldsets.append(( lazy_gettext('run experiment'), {"fields": ['alert_status'], "expanded": True}, )) self.edit_columns.append('alert_status') self.add_form_extra_fields = self.edit_form_extra_fields self.add_fieldsets = self.edit_fieldsets self.add_columns=self.edit_columns # 处理form请求 def process_form(self, form, is_created): # from flask_appbuilder.forms import DynamicForm if 'parameters_demo' in form._fields: del form._fields['parameters_demo'] # 不处理这个字段 # 生成实验 # @pysnooper.snoop() def make_experiment(self,item): # 搜索算法相关 algorithmsettings = [] for setting in item.algorithm_setting.strip().split(','): setting = setting.strip() if setting: key,value = setting.split('=')[0].strip(),setting.split('=')[1].strip() algorithmsettings.append(V1alpha3AlgorithmSetting(name=key,value=value)) algorithm = V1alpha3AlgorithmSpec( algorithm_name=item.algorithm_name, algorithm_settings=algorithmsettings if algorithmsettings else None ) # 实验结果度量,很多中搜集方式,这里不应该写死这个。 metrics_collector_spec=None if item.job_type=='TFJob': collector = V1alpha3CollectorSpec(kind="TensorFlowEvent") source = V1alpha3SourceSpec(V1alpha3FileSystemPath(kind="Directory", path="/train")) metrics_collector_spec = V1alpha3MetricsCollectorSpec( collector=collector, source=source) elif item.job_type=='Job': pass # 目标函数 objective = V1alpha3ObjectiveSpec( goal=item.objective_goal, objective_metric_name=item.objective_metric_name, type=item.objective_type) # 搜索参数 parameters=[] hp_parameters = json.loads(item.parameters) for parameter in hp_parameters: if hp_parameters[parameter]['type']=='int' or hp_parameters[parameter]['type']=='double': feasible_space = V1alpha3FeasibleSpace( min=str(hp_parameters[parameter]['min']), max=str(hp_parameters[parameter]['max']), step = str(hp_parameters[parameter].get('step','')) if hp_parameters[parameter].get('step','') else None) parameters.append(V1alpha3ParameterSpec( feasible_space=feasible_space, name=parameter, parameter_type=hp_parameters[parameter]['type'] )) elif hp_parameters[parameter]['type']=='categorical': feasible_space = V1alpha3FeasibleSpace(list=hp_parameters[parameter]['list']) parameters.append(V1alpha3ParameterSpec( feasible_space=feasible_space, name=parameter, parameter_type=hp_parameters[parameter]['type'] )) # 实验模板 go_template = V1alpha3GoTemplate( raw_template=item.trial_spec ) trial_template = V1alpha3TrialTemplate(go_template=go_template) labels = { "run-rtx":g.user.username, "hp-name":item.name, # "hp-describe": item.describe } # Experiment 跑实例测试 experiment = V1alpha3Experiment( api_version= conf.get('CRD_INFO')['experiment']['group']+"/"+ conf.get('CRD_INFO')['experiment']['version'] ,#"kubeflow.org/v1alpha3", kind="Experiment", metadata=V1ObjectMeta(name=item.name+"-"+uuid.uuid4().hex[:4], namespace=conf.get('KATIB_NAMESPACE'),labels=labels), spec=V1alpha3ExperimentSpec( algorithm=algorithm, max_failed_trial_count=item.max_failed_trial_count, max_trial_count=item.max_trial_count, metrics_collector_spec=metrics_collector_spec, objective=objective, parallel_trial_count=item.parallel_trial_count, parameters=parameters, trial_template=trial_template ) ) item.experiment = json.dumps(experiment.to_dict(),indent=4,ensure_ascii=False) @expose('/create_experiment/',methods=['GET']) # @pysnooper.snoop(watch_explode=('hp',)) def create_experiment(self,id): hp = db.session.query(Hyperparameter_Tuning).filter(Hyperparameter_Tuning.id == int(id)).first() if hp: from myapp.utils.py.py_k8s import K8s k8s_client = K8s(hp.project.cluster['KUBECONFIG']) namespace = conf.get('KATIB_NAMESPACE') crd_info =conf.get('CRD_INFO')['experiment'] print(hp.experiment) k8s_client.create_crd(group=crd_info['group'],version=crd_info['version'],plural=crd_info['plural'],namespace=namespace,body=hp.experiment) flash('部署完成','success') # kclient = kc.KatibClient() # kclient.create_experiment(hp, namespace=conf.get('KATIB_NAMESPACE')) self.update_redirect() return redirect(self.get_redirect()) # @pysnooper.snoop(watch_explode=()) def merge_trial_spec(self,item): image_secrets = conf.get('HUBSECRET',[]) user_hubsecrets = db.session.query(Repository.hubsecret).filter(Repository.created_by_fk == g.user.id).all() if user_hubsecrets: for hubsecret in user_hubsecrets: if hubsecret[0] not in image_secrets: image_secrets.append(hubsecret[0]) image_secrets = [ { "name": hubsecret } for hubsecret in image_secrets ] item.job_json={} if item.job_type=='TFJob': item.trial_spec=core.merge_tfjob_experiment_template( worker_num=item.tf_worker_num, node_selector=item.node_selector, volume_mount=item.volume_mount, image=item.tf_worker_image, image_secrets = image_secrets, workingDir=item.working_dir, image_pull_policy=item.image_pull_policy, resource_memory=item.resource_memory, resource_cpu=item.resource_cpu, command=item.tf_worker_command ) item.job_json={ "tf_worker_num":item.tf_worker_num, "tf_worker_image": item.tf_worker_image, "tf_worker_command": item.tf_worker_command, } if item.job_type == 'Job': item.trial_spec=core.merge_job_experiment_template( node_selector=item.node_selector, volume_mount=item.volume_mount, image=item.job_worker_image, image_secrets=image_secrets, workingDir=item.working_dir, image_pull_policy=item.image_pull_policy, resource_memory=item.resource_memory, resource_cpu=item.resource_cpu, command=item.job_worker_command ) item.job_json = { "job_worker_image": item.job_worker_image, "job_worker_command": item.job_worker_command, } if item.job_type == 'PyTorchJob': item.trial_spec=core.merge_pytorchjob_experiment_template( worker_num=item.pytorch_worker_num, node_selector=item.node_selector, volume_mount=item.volume_mount, image=item.pytorch_worker_image, image_secrets=image_secrets, workingDir=item.working_dir, image_pull_policy=item.image_pull_policy, resource_memory=item.resource_memory, resource_cpu=item.resource_cpu, master_command=item.pytorch_master_command, worker_command=item.pytorch_worker_command ) item.job_json = { "pytorch_worker_num":item.pytorch_worker_num, "pytorch_worker_image": item.pytorch_worker_image, "pytorch_master_command": item.pytorch_master_command, "pytorch_worker_command": item.pytorch_worker_command, } item.job_json = json.dumps(item.job_json,indent=4,ensure_ascii=False) # 检验参数是否有效 # @pysnooper.snoop() def validate_parameters(self,parameters,algorithm): try: parameters = json.loads(parameters) for parameter_name in parameters: parameter = parameters[parameter_name] if parameter['type'] == 'int' and 'min' in parameter and 'max' in parameter: parameter['min'] = int(parameter['min']) parameter['max'] = int(parameter['max']) if not parameter['max']>parameter['min']: raise Exception('min must lower than max') continue if parameter['type'] == 'double' and 'min' in parameter and 'max' in parameter: parameter['min'] = float(parameter['min']) parameter['max'] = float(parameter['max']) if not parameter['max']>parameter['min']: raise Exception('min must lower than max') if algorithm=='grid': parameter['step'] = float(parameter['step']) continue if parameter['type']=='categorical' and 'list' in parameter and type(parameter['list'])==list: continue raise MyappException('parameters type must in [int,double,categorical], and min\max\step\list should exist, and min must lower than max ') return json.dumps(parameters,indent=4,ensure_ascii=False) except Exception as e: print(e) raise MyappException('parameters not valid:'+str(e)) # @pysnooper.snoop() def pre_add(self, item): if item.job_type is None: raise MyappException("Job type is mandatory") core.validate_json(item.parameters) item.parameters = self.validate_parameters(item.parameters,item.algorithm_name) item.resource_memory=core.check_resource_memory(item.resource_memory,self.src_item_json.get('resource_memory',None) if self.src_item_json else None) item.resource_cpu = core.check_resource_cpu(item.resource_cpu,self.src_item_json.get('resource_cpu',None) if self.src_item_json else None) self.merge_trial_spec(item) self.make_experiment(item) def pre_update(self, item): self.pre_add(item) pre_add_get=set_column pre_update_get=set_column @action( "copy", __("Copy Hyperparameter Experiment"), confirmation=__('Copy Hyperparameter Experiment'), icon="fa-copy",multiple=True, single=False ) def copy(self, hps): if not isinstance(hps, list): hps = [hps] for hp in hps: new_hp = hp.clone() new_hp.name = new_hp.name+"-copy" new_hp.describe = new_hp.describe + "-copy" new_hp.created_on = datetime.datetime.now() new_hp.changed_on = datetime.datetime.now() db.session.add(new_hp) db.session.commit() return redirect(request.referrer) class Hyperparameter_Tuning_ModelView(Hyperparameter_Tuning_ModelView_Base,MyappModelView): datamodel = SQLAInterface(Hyperparameter_Tuning) conv = GeneralModelConverter(datamodel) # 添加视图和菜单 appbuilder.add_view(Hyperparameter_Tuning_ModelView,"katib超参搜索",icon = 'fa-shopping-basket',category = '超参搜索',category_icon = 'fa-glass') # 添加api class Hyperparameter_Tuning_ModelView_Api(Hyperparameter_Tuning_ModelView_Base,MyappModelRestApi): datamodel = SQLAInterface(Hyperparameter_Tuning) conv = GeneralModelConverter(datamodel) route_base = '/hyperparameter_tuning_modelview/api' list_columns = ['created_by','changed_by','created_on','changed_on','job_type','name','namespace','describe', 'parallel_trial_count','max_trial_count','max_failed_trial_count','objective_type', 'objective_goal','objective_metric_name','objective_additional_metric_names','algorithm_name', 'algorithm_setting','parameters','job_json','trial_spec','working_dir','node_selector', 'image_pull_policy','resource_memory','resource_cpu','experiment','alert_status'] add_columns = ['job_type','name','namespace','describe', 'parallel_trial_count','max_trial_count','max_failed_trial_count','objective_type', 'objective_goal','objective_metric_name','objective_additional_metric_names','algorithm_name', 'algorithm_setting','parameters','job_json','working_dir','node_selector','image_pull_policy', 'resource_memory','resource_cpu'] edit_columns = add_columns appbuilder.add_api(Hyperparameter_Tuning_ModelView_Api) # list正在运行的Experiments from myapp.views.view_workflow import Crd_ModelView_Base from myapp.models.model_katib import Experiments class Experiments_ModelView(Crd_ModelView_Base,MyappModelView,DeleteMixin): label_title='超参调度' datamodel = SQLAInterface(Experiments) list_columns = ['url','namespace_url','create_time','status','username'] crd_name = 'experiment' appbuilder.add_view(Experiments_ModelView,"katib超参调度",icon = 'fa-tasks',category = '超参搜索') # 添加api class Experiments_ModelView_Api(Crd_ModelView_Base,MyappModelRestApi): datamodel = SQLAInterface(Experiments) route_base = '/experiments_modelview/api' list_columns = ['url', 'namespace_url', 'create_time', 'status', 'username'] crd_name = 'experiment' appbuilder.add_api(Experiments_ModelView_Api)