
Windows7 64bit VS Caffe test MNIST: step-by-step

Time: 2022-06-01 15:12:29



In /fengbingchun/article/details/49849225, Caffe was used to train on the MNIST database and produce a model. This post shows how to apply that model to recognize actual digit images.

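The test images are the same as in /fengbingchun/article/details/50573841: 10 images in total, one per digit 0-9 (loaded as 0.png through 9.png by the test code).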

Where the test-time prototxt differs from the training one, it is modified as follows (note the MemoryData input layer and the final Softmax layer):

name: "LeNet"
layer {
  name: "data"
  type: "MemoryData"
  top: "data"
  top: "label"
  memory_data_param {
    batch_size: 1
    channels: 1
    height: 28
    width: 28
  }
  transform_param {
    scale: 0.00390625
  }
}
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"
}
layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "prob"
  type: "Softmax"
  bottom: "ip2"
  top: "prob"
}

The test code is as follows:

#include <iostream>
#include <glog/logging.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
#include "caffe/util/io.hpp"
#include "caffe/blob.hpp"

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

using caffe::Blob;
using caffe::Caffe;
using caffe::Net;
using caffe::Layer;
using caffe::shared_ptr;
using caffe::string;
using caffe::Timer;
using caffe::vector;
using std::ostringstream;

DEFINE_string(model, "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_train_test_.prototxt",
    "The model definition protocol buffer text file..");
DEFINE_string(weights, "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_iter_10000.caffemodel",
    "Optional; the pretrained weights to initialize finetuning, "
    "separated by ','. Cannot be set simultaneously with snapshot.");

// A simple registry for caffe commands.
typedef int(*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

static BrewFunction GetBrewFunction(const caffe::string& name) {
    if (g_brew_map.count(name)) {
        return g_brew_map[name];
    } else {
        LOG(ERROR) << "Available caffe actions:";
        for (BrewMap::iterator it = g_brew_map.begin(); it != g_brew_map.end(); ++it) {
            LOG(ERROR) << "\t" << it->first;
        }
        LOG(FATAL) << "Unknown action: " << name;
        return NULL;  // not reachable, just to suppress old compiler warnings.
    }
}

// caffe commands to call by
//     caffe <command> <args>
//
// To add a command, define a function "int command()" and register it with
// RegisterBrewFunction(action);

// Test: score a model.
int test() {
    CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to score.";
    CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score.";

    LOG(INFO) << "Use CPU.";
    Caffe::set_mode(Caffe::CPU);

    // Instantiate the caffe net.
    Net<float> caffe_net(FLAGS_model, caffe::TEST);
    caffe_net.CopyTrainedLayersFrom(FLAGS_weights);

    int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    int result[10] = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 };
    std::string image_path = "E:/GitCode/Caffe_Test/test_data/images/";

    for (int i = 0; i < 10; i++) {
        // Build the path "<image_path><i>.png"
        char ch[15];
        sprintf(ch, "%d", i);
        std::string str;
        str = std::string(ch);
        str += ".png";
        str = image_path + str;

        cv::Mat mat = cv::imread(str.c_str(), 1);
        if (!mat.data) {
            std::cout << "load image error" << std::endl;
            return -1;
        }

        // Grayscale, 28x28, and invert so the digit is white on black like MNIST
        cv::cvtColor(mat, mat, CV_BGR2GRAY);
        cv::resize(mat, mat, cv::Size(28, 28));
        cv::bitwise_not(mat, mat);

        // set the patch for testing
        vector<cv::Mat> patches;
        patches.push_back(mat);

        // push vector<Mat> to data layer
        float loss = 0.0;
        boost::shared_ptr<caffe::MemoryDataLayer<float> > memory_data_layer;
        memory_data_layer = boost::static_pointer_cast<caffe::MemoryDataLayer<float>>(caffe_net.layer_by_name("data"));

        vector<int> labels(patches.size());
        memory_data_layer->AddMatVector(patches, labels);

        // Net forward; the net outputs are the unconsumed blobs "label" and "prob",
        // and results[1] holds the 10 softmax probabilities
        const vector<Blob<float>*> & results = caffe_net.ForwardPrefilled(&loss);
        float *output = results[1]->mutable_cpu_data();

        float tmp = -1;
        int pos = -1;

        // Display the output and keep the index of the highest probability
        std::cout << "actual digit is: " << i << std::endl;
        for (int j = 0; j < 10; j++) {
            printf("Probability to be Number %d is %.3f\n", j, output[j]);
            if (tmp < output[j]) {
                pos = j;
                tmp = output[j];
            }
        }

        result[i] = pos;
    }

    for (int i = 0; i < 10; i++) {
        std::cout << "actual digit is : " << target[i] << ", result digit is: " << result[i] << std::endl;
    }

    return 0;
}
RegisterBrewFunction(test);

int main(int argc, char* argv[])
{
    // //07/16/caffe-vs-opencv-in-windows-tutorial-ii/
    // //01/11/build-caffe-in-windows-with-visual-studio--cuda-6-5-opencv-2-4-9/
    // /BVLC/caffe/issues/2499
    // /entry/139417
    // /tag/caffe
    // /BVLC/caffe/pull/1907

    argc = 2;
#ifdef _DEBUG
    argv[0] = "E:/GitCode/Caffe_Test/lib/dbg/x86_vc12/test_mnist[dbg_x86_vc12].exe";
#else
    argv[0] = "E:/GitCode/Caffe_Test/lib/rel/x86_vc12/test_mnist[rel_x86_vc12].exe";
#endif
    argv[1] = "test";

    // InitGoogleLogging() must be called at least once per process, otherwise no log file is written
    google::InitGoogleLogging(argv[0]);
    // Directory for the log files; it must already exist
    FLAGS_log_dir = "E:\\GitCode\\Caffe_Test\\test_data";
    FLAGS_max_log_size = 1024; // MB
    // Print output to stderr (while still logging).
    FLAGS_alsologtostderr = 1;

    // Usage message.
    gflags::SetUsageMessage("command line brew\n"
        "usage: caffe <command> <args>\n\n"
        "commands:\n"
        " test score a model");
    // Run tool or show usage.
    //caffe::GlobalInit(&argc, &argv);
    // Parse command-line flags
    gflags::ParseCommandLineFlags(&argc, &argv, true);
    if (argc == 2) {
        return GetBrewFunction(caffe::string(argv[1]))();
    } else {
        gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
    }

    std::cout << "OK!!!" << std::endl;
    return 0;
}

The result screenshot is as follows:

The results show an accuracy of 70%: the digits 6, 8 and 9 were misrecognized as 8, 2 and 1, respectively.

GitHub: /fengbingchun/Caffe_Test
