@@ -143,7 +143,6 @@ ITensor* residualDenseBlock(INetworkDefinition *network, std::map<std::string, W
143143 return ew1->getOutput (0 );
144144}
145145
146-
147146ITensor* RRDB (INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor* x, std::string lname)
148147{
149148 ITensor* out = residualDenseBlock (network, weightMap, x, lname + " .rdb1" );
@@ -253,7 +252,24 @@ void createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig*
253252
254253 // Build engine
255254 builder->setMaxBatchSize (maxBatchSize);
256- config->setMaxWorkspaceSize (1 << 20 );
255+ // config->setMaxWorkspaceSize(1 << 22);
256+ config->setMaxWorkspaceSize (28 * (1 << 23 )); // 28MB
257+
258+ if (precision_mode == 16 ) {
259+ std::cout << " ==== precision f16 ====" << std::endl << std::endl;
260+ config->setFlag (BuilderFlag::kFP16 );
261+ }
262+ else if (precision_mode == 8 ) {
263+ // std::cout << "==== precision int8 ====" << std::endl << std::endl;
264+ // std::cout << "Your platform support int8: " << builder->platformHasFastInt8() << std::endl;
265+ // assert(builder->platformHasFastInt8());
266+ // config->setFlag(BuilderFlag::kINT8);
267+ // Int8EntropyCalibrator2 *calibrator = new Int8EntropyCalibrator2(maxBatchSize, INPUT_W, INPUT_H, 0, "../data_calib/", "../Int8_calib_table/detr_int8_calib.table", INPUT_BLOB_NAME);
268+ // config->setInt8Calibrator(calibrator);
269+ }
270+ else {
271+ std::cout << " ==== precision f32 ====" << std::endl << std::endl;
272+ }
257273
258274 std::cout << " Building engine, please wait for a while..." << std::endl;
259275 IHostMemory* engine = builder->buildSerializedNetwork (*network, *config);
@@ -285,7 +301,7 @@ int main()
285301 char engineFileName[] = " real-esrgan" ;
286302
287303 char engine_file_path[256 ];
288- sprintf (engine_file_path, " ../Engine/%s .engine" , engineFileName);
304+ sprintf (engine_file_path, " ../Engine/%s_%d .engine" , engineFileName, precision_mode );
289305
290306 // 1) engine file 만들기
291307 // 강제 만들기 true면 무조건 다시 만들기
@@ -359,7 +375,7 @@ int main()
359375 std::cout << " ===== input load done =====" << std::endl << std::endl;
360376
361377 uint64_t dur_time = 0 ;
362- uint64_t iter_count = 1 ;
378+ uint64_t iter_count = 10 ;
363379
364380 // CUDA 스트림 생성
365381 cudaStream_t stream;
0 commit comments