Skip to content

Commit acd34f0

Browse files
committed
progress monitoring cleanup, fix cancellation issues
1 parent 5af7fa3 commit acd34f0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+362
-263
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Version History
22
---------------
33

4+
- Fixed issues with cancellation through progress monitor callbacks:
5+
- Fixed cancellation requests almost never being fulfilled on CPU
6+
devices since `v2.3.0`
7+
- Fixed not calling the callback anymore after requesting cancellation,
8+
while the operation is still being executed
9+
410
### Changes in v2.3.0:
511

612
- Significantly improved image quality of the `RT` filter in *high* quality

api/api.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ OIDN_API_NAMESPACE_BEGIN
550550
checkHandle(hDevice);
551551
OIDN_LOCK_DEVICE(device);
552552
device->checkCommitted();
553-
device->wait();
553+
device->waitAndThrow();
554554
OIDN_CATCH_DEVICE(device)
555555
}
556556

apps/oidnTest.cpp

+31-20
Original file line numberDiff line numberDiff line change
@@ -941,23 +941,11 @@ bool progressCallback(void* userPtr, double n)
941941
return n < progress->nMax; // cancel if reached nMax
942942
}
943943

944-
void progressTest(DeviceRef& device, double nMax = 1000)
944+
void progressTest(DeviceRef& device, FilterRef& filter, double nMax = 1000)
945945
{
946-
const int W = 1283;
947-
const int H = 727;
948-
949-
FilterRef filter = device.newFilter("RT");
950-
REQUIRE(bool(filter));
951-
952-
auto image = makeConstImage(device, W, H);
953-
setFilterImage(filter, "color", image);
954-
setFilterImage(filter, "output", image); // in-place
955-
956946
Progress progress(nMax);
957947
filter.setProgressMonitorFunction(progressCallback, &progress);
958948

959-
filter.set("maxMemoryMB", 0); // make sure there will be multiple tiles
960-
961949
filter.commit();
962950
REQUIRE(device.getError() == Error::None);
963951

@@ -968,8 +956,7 @@ void progressTest(DeviceRef& device, double nMax = 1000)
968956
// Execution should be cancelled but it's not guaranteed
969957
Error error = device.getError();
970958
REQUIRE((error == Error::None || error == Error::Cancelled));
971-
// Check whether the callback has not been called after requesting cancellation
972-
REQUIRE(progress.n >= nMax);
959+
REQUIRE((progress.n >= nMax && progress.n <= 1));
973960
}
974961
else
975962
{
@@ -981,26 +968,50 @@ void progressTest(DeviceRef& device, double nMax = 1000)
981968

982969
TEST_CASE("progress monitor", "[progress]")
983970
{
971+
const int W = 1283;
972+
const int H = 727;
973+
984974
DeviceRef device = makeAndCommitDevice();
985975

986-
SECTION("progress monitor: complete")
976+
FilterRef filter = device.newFilter("RT");
977+
REQUIRE(bool(filter));
978+
979+
auto image = makeConstImage(device, W, H);
980+
setFilterImage(filter, "color", image);
981+
setFilterImage(filter, "output", image); // in-place
982+
983+
filter.set("maxMemoryMB", 0); // make sure there will be multiple tiles
984+
985+
SECTION("progress monitor: finish")
987986
{
988-
progressTest(device);
987+
progressTest(device, filter);
989988
}
990989

991990
SECTION("progress monitor: cancel at the middle")
992991
{
993-
progressTest(device, 0.5);
992+
progressTest(device, filter, 0.5);
994993
}
995994

996995
SECTION("progress monitor: cancel at the beginning")
997996
{
998-
progressTest(device, 0);
997+
progressTest(device, filter, 0);
999998
}
1000999

10011000
SECTION("progress monitor: cancel at the end")
10021001
{
1003-
progressTest(device, 1);
1002+
progressTest(device, filter, 1);
1003+
}
1004+
1005+
SECTION("progress monitor: cancel around the middle, finish")
1006+
{
1007+
progressTest(device, filter, 0.4);
1008+
progressTest(device, filter);
1009+
}
1010+
1011+
SECTION("progress monitor: finish, cancel around the middle")
1012+
{
1013+
progressTest(device, filter);
1014+
progressTest(device, filter, 0.6);
10041015
}
10051016
}
10061017

core/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(OIDN_CORE_SOURCES
4848
module.h
4949
module.cpp
5050
op.h
51+
op.cpp
5152
output_process.h
5253
output_process.cpp
5354
pool.h

core/autoexposure.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ OIDN_NAMESPACE_BEGIN
2020

2121
#if !defined(OIDN_COMPILE_METAL_DEVICE)
2222

23-
class Autoexposure : public Op, public AutoexposureParams
23+
class Autoexposure : public BaseOp, public AutoexposureParams
2424
{
2525
public:
2626
explicit Autoexposure(const ImageDesc& srcDesc)

core/concat_conv.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ OIDN_NAMESPACE_BEGIN
1818
bool fastMath; // prefer performance over accuracy
1919
};
2020

21-
class ConcatConv : public Op, protected ConcatConvDesc
21+
class ConcatConv : public BaseOp, protected ConcatConvDesc
2222
{
2323
public:
2424
ConcatConv(const ConcatConvDesc& desc);

core/concat_conv_chw.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ OIDN_NAMESPACE_BEGIN
1414
public:
1515
ConcatConvCHW(Engine* engine, const ConcatConvDesc& desc);
1616

17-
size_t getScratchByteSize() const override { return conv->getScratchByteSize(); }
17+
Engine* getEngine() const override { return conv->getEngine(); }
18+
19+
size_t getScratchByteSize() override { return conv->getScratchByteSize(); }
1820
void setScratch(const Ref<Buffer>& scratch) override { conv->setScratch(scratch); }
1921

2022
void setWeight(const Ref<Tensor>& weight) { conv->setWeight(weight); }
2123

2224
void finalize() override { conv->finalize(); }
23-
void submit() override { conv->submit(); }
25+
void submitKernels(const Ref<CancellationToken>& ct) override { conv->submitKernels(ct); }
2426

2527
private:
2628
void updateSrc() override;

core/concat_conv_hwc.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ OIDN_NAMESPACE_BEGIN
3636
return conv1->isSupported() && conv2->isSupported();
3737
}
3838

39-
size_t ConcatConvHWC::getScratchByteSize() const
39+
size_t ConcatConvHWC::getScratchByteSize()
4040
{
4141
return max(conv1->getScratchByteSize(), conv2->getScratchByteSize());
4242
}
@@ -73,10 +73,10 @@ OIDN_NAMESPACE_BEGIN
7373
conv2->finalize();
7474
}
7575

76-
void ConcatConvHWC::submit()
76+
void ConcatConvHWC::submitKernels(const Ref<CancellationToken>& ct)
7777
{
78-
conv1->submit();
79-
conv2->submit();
78+
conv1->submitKernels(ct);
79+
conv2->submitKernels(ct);
8080
}
8181

8282
OIDN_NAMESPACE_END

core/concat_conv_hwc.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@ OIDN_NAMESPACE_BEGIN
1515
public:
1616
ConcatConvHWC(Engine* engine, const ConcatConvDesc& desc);
1717

18+
Engine* getEngine() const override { return conv1->getEngine(); }
1819
bool isSupported() const override;
1920

20-
size_t getScratchByteSize() const override;
21+
size_t getScratchByteSize() override;
2122
void setScratch(const Ref<Buffer>& scratch) override;
2223

2324
TensorDesc getWeight1Desc() const { return weight1Desc; }
2425
TensorDesc getWeight2Desc() const { return weight2Desc; }
2526
void setWeight(const Ref<Tensor>& weight1, const Ref<Tensor>& weight2);
2627

2728
void finalize() override;
28-
void submit() override;
29+
void submitKernels(const Ref<CancellationToken>& ct) override;
2930

3031
private:
3132
void updateSrc() override;

core/conv.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ OIDN_NAMESPACE_BEGIN
3434
};
3535

3636
// Convolution
37-
class Conv : public Op, protected ConvDesc
37+
class Conv : public BaseOp, protected ConvDesc
3838
{
3939
public:
4040
Conv(const ConvDesc& desc);

core/device.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "device.h"
55
#include "subdevice.h"
6+
#include "engine.h"
67
#include "context.h"
78
#include "rt_filter.h"
89
#include "rtlightmap_filter.h"
@@ -159,6 +160,17 @@ OIDN_NAMESPACE_BEGIN
159160
}
160161
}
161162

163+
void Device::setAsyncError(Error code, const std::string& message)
164+
{
165+
// Update the stored async error only if the previous error was thrown
166+
std::lock_guard<std::mutex> lock(asyncErrorMutex);
167+
if (asyncError.code == Error::None)
168+
{
169+
asyncError.code = code;
170+
asyncError.message = message;
171+
}
172+
}
173+
162174
void Device::setErrorFunction(ErrorFunction func, void* userPtr)
163175
{
164176
errorFunc = func;
@@ -291,4 +303,19 @@ OIDN_NAMESPACE_BEGIN
291303
subdevice->trimScratch();
292304
}
293305

306+
void Device::waitAndThrow()
307+
{
308+
wait();
309+
310+
// If an asynchronous error was stored, throw it now
311+
std::lock_guard<std::mutex> asyncErrorLock(asyncErrorMutex);
312+
if (asyncError.code != Error::None)
313+
{
314+
const Error code = asyncError.code;
315+
const std::string message = std::move(asyncError.message);
316+
asyncError = {}; // clear the error
317+
throw Exception(code, message);
318+
}
319+
}
320+
294321
OIDN_NAMESPACE_END

core/device.h

+10
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ OIDN_NAMESPACE_BEGIN
5454

5555
static void setError(Device* device, Error code, const std::string& message);
5656
static Error getError(Device* device, const char** outMessage);
57+
58+
void setAsyncError(Error code, const std::string& message);
5759
void setErrorFunction(ErrorFunction func, void* userPtr);
5860

5961
// Some devices (e.g. CUDA, HIP) need to change some per-thread state, which must be later restored
@@ -123,6 +125,10 @@ OIDN_NAMESPACE_BEGIN
123125
// Waits for all previously submitted commands to complete (blocks)
124126
virtual void wait() = 0;
125127

128+
// Waits for all previously submitted commands to complete, and throws the first asynchronous
129+
// error that occurred since the previous invocation of this function (blocks)
130+
void waitAndThrow();
131+
126132
protected:
127133
virtual void init() = 0;
128134

@@ -156,7 +162,11 @@ OIDN_NAMESPACE_BEGIN
156162
};
157163

158164
static thread_local ErrorState globalError;
165+
159166
ThreadLocal<ErrorState> error;
167+
ErrorState asyncError;
168+
std::mutex asyncErrorMutex;
169+
160170
ErrorFunction errorFunc = nullptr;
161171
void* errorUserPtr = nullptr;
162172
};

core/engine.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "heap.h"
1010
#include "buffer.h"
1111
#include "image.h"
12+
#include "progress.h"
1213

1314
OIDN_NAMESPACE_BEGIN
1415

@@ -80,7 +81,8 @@ OIDN_NAMESPACE_BEGIN
8081
virtual void submitUSMCopy(void* dstPtr, const void* srcPtr, size_t byteSize);
8182

8283
// Enqueues a host function
83-
virtual void submitHostFunc(std::function<void()>&& f) = 0;
84+
virtual void submitHostFunc(std::function<void()>&& f,
85+
const Ref<CancellationToken>& ct = nullptr) = 0;
8486

8587
// Issues all previously submitted commands (does not block)
8688
virtual void flush() {}

core/graph.cpp

+5-10
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ OIDN_NAMESPACE_BEGIN
332332
tensorScratchPlanner.addDepAllocs(opID, srcAllocIDs, concatSrcs);
333333

334334
ops.push_back(op);
335+
workAmount += op->getWorkAmount();
335336
dirty = true;
336337
}
337338

@@ -374,11 +375,6 @@ OIDN_NAMESPACE_BEGIN
374375
dirty = false;
375376
}
376377

377-
double Graph::getWorkAmount() const
378-
{
379-
return double(ops.size());
380-
}
381-
382378
bool Graph::isSupported() const
383379
{
384380
for (const auto& opTensorAllocPair : tensorAllocs)
@@ -423,6 +419,7 @@ OIDN_NAMESPACE_BEGIN
423419
scratch.reset();
424420
scratchByteSize = 0;
425421
privateByteSize = 0;
422+
workAmount = 0;
426423
tensorScratchByteOffset = 0;
427424
dirty = false;
428425
}
@@ -455,7 +452,7 @@ OIDN_NAMESPACE_BEGIN
455452
finalized = true;
456453
}
457454

458-
void Graph::submit(Progress& progress)
455+
void Graph::submit(const Ref<Progress>& progress)
459456
{
460457
if (!finalized)
461458
throw std::logic_error("graph not finalized");
@@ -468,14 +465,14 @@ OIDN_NAMESPACE_BEGIN
468465

469466
for (size_t i = 0; i < ops.size(); ++i)
470467
{
471-
ops[i]->submit();
468+
ops[i]->submit(progress);
472469

473470
#if defined(OIDN_MICROBENCH)
474471
engine->wait();
475472
const int numRuns = OIDN_MICROBENCH;
476473
Timer timer;
477474
for (int j = 0; j < numRuns; ++j)
478-
ops[i]->submit();
475+
ops[i]->submit(progress);
479476
engine->wait();
480477
const double time = timer.query() / numRuns;
481478
std::cerr << i << "," << ops[i]->getName() << "," << time * 1000 << std::endl;
@@ -505,8 +502,6 @@ OIDN_NAMESPACE_BEGIN
505502
dst->dump(toString(i) + "_" + ops[i]->getName() + "_");
506503
}
507504
#endif
508-
509-
progress.update(engine, 1);
510505
}
511506

512507
#if defined(OIDN_MICROBENCH)

0 commit comments

Comments
 (0)