Skip to content

[SYCL] Fix weak_object and owner_less for device objects #8740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sycl/include/sycl/accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ class __SYCL_EXPORT LocalAccessorBaseHost {

protected:
template <class Obj>
friend decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject);
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);

template <class T>
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
Expand Down Expand Up @@ -1209,6 +1209,9 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
friend class sycl::stream;
friend class sycl::ext::intel::esimd::detail::AccessorPrivateProxy;

template <class Obj>
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);

template <class T>
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

Expand Down Expand Up @@ -2528,6 +2531,9 @@ class __SYCL_SPECIAL_CLASS local_accessor_base :
return Result;
}

template <class Obj>
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);

template <class T>
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

Expand Down
7 changes: 7 additions & 0 deletions sycl/include/sycl/detail/owner_less_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ template <class SyclObjT> class OwnerLessBase {
return getSyclObjImpl(*static_cast<const SyclObjT *>(this))
.owner_before(getSyclObjImpl(Other));
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
bool ext_oneapi_owner_before(
const ext::oneapi::detail::weak_object_base<SyclObjT> &Other)
const noexcept;
bool ext_oneapi_owner_before(const SyclObjT &Other) const noexcept;
#endif
};

Expand Down
18 changes: 16 additions & 2 deletions sycl/include/sycl/ext/oneapi/weak_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ class weak_object : public detail::weak_object_base<SYCLObjT> {

weak_object &operator=(const SYCLObjT &SYCLObj) noexcept {
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
this->MObjWeakPtr = sycl::detail::getSyclObjImpl(SYCLObj);
this->MObjWeakPtr = GetWeakImpl(SYCLObj);
return *this;
}
weak_object &operator=(const weak_object &Other) noexcept = default;
weak_object &operator=(weak_object &&Other) noexcept = default;

#ifndef __SYCL_DEVICE_ONLY__
std::optional<SYCLObjT> try_lock() const noexcept {
auto MObjImplPtr = this->MObjWeakPtr.lock();
if (!MObjImplPtr)
Expand All @@ -69,6 +70,12 @@ class weak_object : public detail::weak_object_base<SYCLObjT> {
"Referenced object has expired.");
return *OptionalObj;
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
std::optional<SYCLObjT> try_lock() const noexcept;
SYCLObjT lock() const;
#endif // __SYCL_DEVICE_ONLY__
};

// Specialization of weak_object for buffer as it needs additional members
Expand Down Expand Up @@ -96,7 +103,7 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>

weak_object &operator=(const buffer_type &SYCLObj) noexcept {
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
this->MObjWeakPtr = sycl::detail::getSyclObjImpl(SYCLObj);
this->MObjWeakPtr = GetWeakImpl(SYCLObj);
this->MRange = SYCLObj.Range;
this->MOffsetInBytes = SYCLObj.OffsetInBytes;
this->MIsSubBuffer = SYCLObj.IsSubBuffer;
Expand All @@ -105,6 +112,7 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
weak_object &operator=(const weak_object &Other) noexcept = default;
weak_object &operator=(weak_object &&Other) noexcept = default;

#ifndef __SYCL_DEVICE_ONLY__
std::optional<buffer_type> try_lock() const noexcept {
auto MObjImplPtr = this->MObjWeakPtr.lock();
if (!MObjImplPtr)
Expand All @@ -119,6 +127,12 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
"Referenced object has expired.");
return *OptionalObj;
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
std::optional<buffer_type> try_lock() const noexcept;
buffer_type lock() const;
#endif // __SYCL_DEVICE_ONLY__

private:
// Additional members required for recreating buffers.
Expand Down
21 changes: 19 additions & 2 deletions sycl/include/sycl/ext/oneapi/weak_object_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ template <typename SYCLObjT> class weak_object_base {

constexpr weak_object_base() noexcept : MObjWeakPtr() {}
weak_object_base(const SYCLObjT &SYCLObj) noexcept
: MObjWeakPtr(sycl::detail::getSyclObjImpl(SYCLObj)) {}
: MObjWeakPtr(GetWeakImpl(SYCLObj)) {}
weak_object_base(const weak_object_base &Other) noexcept = default;
weak_object_base(weak_object_base &&Other) noexcept = default;

Expand All @@ -43,19 +43,36 @@ template <typename SYCLObjT> class weak_object_base {

bool expired() const noexcept { return MObjWeakPtr.expired(); }

#ifndef __SYCL_DEVICE_ONLY__
bool owner_before(const SYCLObjT &Other) const noexcept {
return MObjWeakPtr.owner_before(sycl::detail::getSyclObjImpl(Other));
return MObjWeakPtr.owner_before(GetWeakImpl(Other));
}
bool owner_before(const weak_object_base &Other) const noexcept {
return MObjWeakPtr.owner_before(Other.MObjWeakPtr);
}
#else
// On device calls to these functions are disallowed, so declare them but
// don't define them to avoid compilation failures.
bool owner_before(const SYCLObjT &Other) const noexcept;
bool owner_before(const weak_object_base &Other) const noexcept;
#endif // __SYCL_DEVICE_ONLY__

protected:
#ifndef __SYCL_DEVICE_ONLY__
// Store a weak variant of the impl in the SYCLObjT.
typename std::invoke_result_t<
decltype(sycl::detail::getSyclObjImpl<SYCLObjT>), SYCLObjT>::weak_type
MObjWeakPtr;

static decltype(MObjWeakPtr) GetWeakImpl(const SYCLObjT &SYCLObj) {
return sycl::detail::getSyclObjImpl(SYCLObj);
}
#else
// On device we may not have an impl, so we pad with an unused void pointer.
std::weak_ptr<void> MObjWeakPtr;
static std::weak_ptr<void> GetWeakImpl(const SYCLObjT &) { return {}; }
#endif // __SYCL_DEVICE_ONLY__

template <class Obj>
friend decltype(weak_object_base<Obj>::MObjWeakPtr)
detail::getSyclWeakObjImpl(const weak_object_base<Obj> &WeakObj);
Expand Down
28 changes: 28 additions & 0 deletions sycl/test-e2e/WeakObject/weak_object_copy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %BE_RUN_PLACEHOLDER %t.out

// This test checks the behavior of the copy ctor and assignment operator for
// `weak_object`.

#include "weak_object_utils.hpp"

template <typename SyclObjT> struct WeakObjectCheckCopy {
void operator()(SyclObjT Obj) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj{Obj};

sycl::ext::oneapi::weak_object<SyclObjT> WeakObjCopyCtor{WeakObj};
sycl::ext::oneapi::weak_object<SyclObjT> WeakObjCopyAssign = WeakObj;

assert(!WeakObjCopyCtor.expired());
assert(!WeakObjCopyAssign.expired());

assert(WeakObjCopyCtor.lock() == Obj);
assert(WeakObjCopyAssign.lock() == Obj);
}
};

int main() {
sycl::queue Q;
runTest<WeakObjectCheckCopy>(Q);
return 0;
}
22 changes: 22 additions & 0 deletions sycl/test-e2e/WeakObject/weak_object_expired.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %BE_RUN_PLACEHOLDER %t.out

// This test checks the behavior of `expired()` for `weak_object`.

#include "weak_object_utils.hpp"

template <typename SyclObjT> struct WeakObjectCheckExpired {
void operator()(SyclObjT Obj) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj{Obj};
sycl::ext::oneapi::weak_object<SyclObjT> NullWeakObj;

assert(!WeakObj.expired());
assert(NullWeakObj.expired());
}
};

int main() {
sycl::queue Q;
runTest<WeakObjectCheckExpired>(Q);
return 0;
}
30 changes: 30 additions & 0 deletions sycl/test-e2e/WeakObject/weak_object_lock.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %BE_RUN_PLACEHOLDER %t.out

// This test checks the behavior of `lock()` for `weak_object`.

#include "weak_object_utils.hpp"

template <typename SyclObjT> struct WeakObjectCheckLock {
void operator()(SyclObjT Obj) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj{Obj};
sycl::ext::oneapi::weak_object<SyclObjT> NullWeakObj;

SyclObjT LObj = WeakObj.lock();
assert(LObj == Obj);

try {
SyclObjT LNull = NullWeakObj.lock();
assert(false && "Locking empty weak object did not throw.");
} catch (sycl::exception &E) {
assert(E.code() == sycl::make_error_code(sycl::errc::invalid) &&
"Unexpected thrown error code.");
}
}
};

int main() {
sycl::queue Q;
runTest<WeakObjectCheckLock>(Q);
return 0;
}
31 changes: 31 additions & 0 deletions sycl/test-e2e/WeakObject/weak_object_move.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %BE_RUN_PLACEHOLDER %t.out

// This test checks the behavior of the copy ctor and assignment operator for
// `weak_object`.

#include "weak_object_utils.hpp"

template <typename SyclObjT> struct WeakObjectCheckMove {
void operator()(SyclObjT Obj) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj1{Obj};
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj2{Obj};

sycl::ext::oneapi::weak_object<SyclObjT> WeakObjMoveCtor{
std::move(WeakObj1)};
sycl::ext::oneapi::weak_object<SyclObjT> WeakObjMoveAssign =
std::move(WeakObj2);

assert(!WeakObjMoveCtor.expired());
assert(!WeakObjMoveAssign.expired());

assert(WeakObjMoveCtor.lock() == Obj);
assert(WeakObjMoveAssign.lock() == Obj);
}
};

int main() {
sycl::queue Q;
runTest<WeakObjectCheckMove>(Q);
return 0;
}
52 changes: 52 additions & 0 deletions sycl/test-e2e/WeakObject/weak_object_owner_before.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %BE_RUN_PLACEHOLDER %t.out

// This test checks the behavior of owner_before semantics for `weak_object`.

#include "weak_object_utils.hpp"

template <typename SyclObjT> struct WeakObjectCheckOwnerBefore {
void operator()(SyclObjT Obj) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj{Obj};
sycl::ext::oneapi::weak_object<SyclObjT> NullWeakObj;

assert((WeakObj.owner_before(NullWeakObj) &&
!NullWeakObj.owner_before(WeakObj)) ||
(NullWeakObj.owner_before(WeakObj) &&
!WeakObj.owner_before(NullWeakObj)));

assert(!WeakObj.owner_before(Obj));
assert(!Obj.ext_oneapi_owner_before(WeakObj));

assert(!Obj.ext_oneapi_owner_before(Obj));
}
};

template <typename SyclObjT> struct WeakObjectCheckOwnerBeforeMulti {
void operator()(SyclObjT Obj1, SyclObjT Obj2) {
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj1{Obj1};
sycl::ext::oneapi::weak_object<SyclObjT> WeakObj2{Obj2};

assert(
(WeakObj1.owner_before(WeakObj2) && !WeakObj2.owner_before(WeakObj1)) ||
(WeakObj2.owner_before(WeakObj1) && !WeakObj1.owner_before(WeakObj2)));

assert(!WeakObj1.owner_before(Obj1));
assert(!Obj1.ext_oneapi_owner_before(WeakObj1));

assert(!WeakObj2.owner_before(Obj2));
assert(!Obj2.ext_oneapi_owner_before(WeakObj2));

assert((Obj1.ext_oneapi_owner_before(Obj2) &&
!Obj2.ext_oneapi_owner_before(Obj1)) ||
(Obj2.ext_oneapi_owner_before(Obj1) &&
!Obj1.ext_oneapi_owner_before(Obj2)));
}
};

int main() {
sycl::queue Q;
runTest<WeakObjectCheckOwnerBefore>(Q);
runTestMulti<WeakObjectCheckOwnerBeforeMulti>(Q);
return 0;
}
Loading