forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path quantized.cpp
60 lines (52 loc) · 1.87 KB
/
quantized.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
//
// quantized.cpp
// MNN
//
// Created by MNN on 2019/07/01.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <cstring>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include "calibration.hpp"
#include "logkit.h"
// Entry point of the offline quantization tool.
//
// Usage: ./quantized.out src.mnn dst.mnn preTreatConfig.json
//   argv[1]  path of the source MNN model
//   argv[2]  destination path; passed to Calibration and to dumpTensorScales
//   argv[3]  JSON pre-processing configuration passed to Calibration
//
// Returns 0 on success, 1 on a usage error or an unreadable/empty input model.
int main(int argc, const char* argv[]) {
    if (argc < 4) {
        DLOG(INFO) << "Usage: ./quantized.out src.mnn dst.mnn preTreatConfig.json\n";
        // A usage error is a failure; report it to the calling shell/script.
        return 1;
    }
    const char* modelFile      = argv[1];
    const char* preTreatConfig = argv[3];
    const char* dstFile        = argv[2];
    DLOG(INFO) << ">>> modelFile: " << modelFile;
    DLOG(INFO) << ">>> preTreatConfig: " << preTreatConfig;
    DLOG(INFO) << ">>> dstFile: " << dstFile;

    // Declared before `calibration` so it outlives it: Calibration receives netT.get().
    std::unique_ptr<MNN::NetT> netT;
    {
        // The model is a binary flatbuffer: open in binary mode so that no newline
        // translation can corrupt it on platforms that distinguish text/binary files.
        std::ifstream input(modelFile, std::ios::in | std::ios::binary);
        if (!input.is_open()) {
            DLOG(INFO) << "Failed to open model file: " << modelFile;
            return 1;
        }
        std::ostringstream outputOs;
        outputOs << input.rdbuf();
        const std::string modelContent = outputOs.str();
        if (modelContent.empty()) {
            // Unpacking an empty buffer would read out of bounds.
            DLOG(INFO) << "Model file is empty or unreadable: " << modelFile;
            return 1;
        }
        netT = MNN::UnPackNet(modelContent.data());
    }

    // Temporarily re-serialize the net to obtain a buffer for inference.
    flatbuffers::FlatBufferBuilder builder(1024);
    auto offset = MNN::Net::Pack(builder, netT.get());
    builder.Finish(offset);
    const int size = static_cast<int>(builder.GetSize());
    auto ocontent  = builder.GetBufferPointer();

    // Model buffer for creating the MNN Interpreter. It is allocated with new[], so it
    // must be owned by the array form of unique_ptr (which releases it with delete[]).
    std::unique_ptr<uint8_t[]> modelForInference(new uint8_t[size]);
    std::memcpy(modelForInference.get(), ocontent, size);
    std::unique_ptr<uint8_t[]> modelOriginal(new uint8_t[size]);
    std::memcpy(modelOriginal.get(), ocontent, size);

    // Replace netT with a fresh object tree unpacked from the serialized copy;
    // this is the net handed to Calibration below.
    netT.reset();
    netT = MNN::UnPackNet(modelOriginal.get());

    // Quantize the model's weights.
    DLOG(INFO) << "Calibrate the feature and quantize model...";
    // Declared last so it is destroyed first, while netT and modelForInference
    // (both passed in as non-owning pointers) are still alive.
    std::shared_ptr<Calibration> calibration(
        new Calibration(netT.get(), modelForInference.get(), size, preTreatConfig, std::string(modelFile), std::string(dstFile)));
    calibration->runQuantizeModel();
    calibration->dumpTensorScales(dstFile);
    DLOG(INFO) << "Quantize model done!";
    return 0;
}