Skip to content

Commit

Permalink
add merge mask option
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHEQIUSHUI committed Oct 17, 2023
1 parent c280c8f commit dacdc31
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 800 deletions.
1 change: 1 addition & 0 deletions qtproj/SAMQT/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CMakeLists.txt.user*
345 changes: 0 additions & 345 deletions qtproj/SAMQT/CMakeLists.txt.user

This file was deleted.

436 changes: 0 additions & 436 deletions qtproj/SAMQT/CMakeLists.txt.user.5b6e006

This file was deleted.

6 changes: 3 additions & 3 deletions qtproj/SAMQT/mainwindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void MainWindow::on_btn_remove_obj_clicked()
dilate_size = 111;
if (dilate_size < 5)
dilate_size = 5;
this->ui->label->ShowRemoveObject(dilate_size, this->ui->progressBar_remove_obj);
this->ui->label->ShowRemoveObject(dilate_size, this->ui->progressBar_remove_obj, ui->ch_merge_mask->isChecked());
this->setEnabled(true);
}

Expand Down Expand Up @@ -122,11 +122,11 @@ void MainWindow::on_btn_save_img_clicked()
}
else
{
if(!(filename.endsWith(".bmp") || filename.endsWith(".png") || filename.endsWith(".jpg")))
if (!(filename.endsWith(".bmp") || filename.endsWith(".png") || filename.endsWith(".jpg")))
{
filename += ".png";
}

if (!(cur_image.save(filename))) // 保存图像
{
QMessageBox::information(this,
Expand Down
7 changes: 7 additions & 0 deletions qtproj/SAMQT/mainwindow.ui
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@
</item>
</layout>
</item>
<item>
<widget class="QCheckBox" name="ch_merge_mask">
<property name="text">
<string>MergeMask</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="btn_remove_obj">
<property name="enabled">
Expand Down
63 changes: 47 additions & 16 deletions qtproj/SAMQT/myqlabel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class myQLabel : public QLabel
bool mouseHolding = false;
QPoint pt_img_first, pt_img_secend;
SAM mSam;
// LamaInpaintOnnx mInpaint;
// LamaInpaintOnnx mInpaint;
std::shared_ptr<LamaInpaint> mInpaint;

void dragEnterEvent(QDragEnterEvent *event) override
Expand Down Expand Up @@ -221,7 +221,7 @@ class myQLabel : public QLabel
void InitModel(std::string encoder_model, std::string decoder_model, std::string inpaint_model)
{
mSam.Load(encoder_model, decoder_model);
// mInpaint.Load(inpaint_model);
// mInpaint.Load(inpaint_model);

if (string_utility<std::string>::ends_with(inpaint_model, ".onnx"))
{
Expand All @@ -238,7 +238,7 @@ class myQLabel : public QLabel
mInpaint->Load(inpaint_model);
}

void ShowRemoveObject(int dilate_size, QProgressBar *bar)
void ShowRemoveObject(int dilate_size, QProgressBar *bar, bool remove_mask_by_merge = true)
{
if (!cur_image.bits() || !grab_masks.size())
{
Expand All @@ -262,21 +262,52 @@ class myQLabel : public QLabel
bar->setMinimum(0);
bar->setMaximum(grab_masks.size());
}
for (auto grab_mask : grab_masks)
if (remove_mask_by_merge)
{
auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, grab_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
if (cur_masks.size())
cur_masks.removeFirst();
repaint();
if (bar)
bar->setValue(bar->value() + 1);
if (grab_masks.size())
{
auto base_mask = grab_masks[0];
if (bar)
bar->setValue(bar->value() + 1);

// merge all mask
for (size_t i = 1; i < grab_masks.size(); i++)
{
base_mask |= grab_masks[i];
if (bar)
bar->setValue(bar->value() + 1);
}

auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, base_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
cur_masks.clear();
repaint();
}
}
else
{
for (auto grab_mask : grab_masks)
{
auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, grab_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
if (cur_masks.size())
cur_masks.removeFirst();
repaint();
if (bar)
bar->setValue(bar->value() + 1);
}
}

cur_masks.clear();
rgba_masks.clear();
grab_masks.clear();
Expand Down

0 comments on commit dacdc31

Please sign in to comment.