1、from mindspore.train.callback import ModelCheckpoint, CheckpointConfig# 设置模型保存参数config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)# 应用模型保存参数ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
2、通过MindSpore提供的model.train接口可以方便地进行网络的训练,LossMonitor可以监控训练过程中loss值的变化。# 导入模型训练需要的库from mindspore.nn import Accuracyfrom mindspore.train.callback import LossMonitorfrom mindspore import Model
3、def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode): """定义训练的方法""" # 加载训练数据集 ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size) model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)
4、其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到Device上,可以加快训练速度,dataset_sink_mode为True表示数据下沉,否则为非下沉。通过模型运行测试数据集得到的结果,验证模型的泛化能力。使用model.eval接口读入测试数据集。使用保存后的模型参数进行推理。
5、def test_net(network, model, data_path): """定义验证的方法""" ds_eval = create_dataset(os.path.join(data_path, "test")) acc = model.eval(ds_eval, dataset_sink_mode=False) print("{}".format(acc))
6、train_epoch = 1mnist_path = "./datasets/MNIST_Data"dataset_size = 1model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)test_net(net, model, mnist_path)