Skip to content

Commit

Permalink
Relax thread-local constraint for the device sharable by multiple thre…
Browse files Browse the repository at this point in the history
…ads (#672)
  • Loading branch information
deukhyun-cha authored and kris-rowe committed Sep 12, 2023
1 parent f992782 commit 6b1c429
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 3 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ if(ENABLE_FORTRAN)
enable_language(Fortran)
endif()

# Build-time switch: lets multiple threads share a device.
# OCCA_THREAD_SHARABLE_ENABLED is always defined (as 0 or 1) so that sources
# can test it with a plain `#if`.
option(ENABLE_SHARABLE_DEVICE "Enable sharable device by multiple threads" OFF)
if (NOT ENABLE_SHARABLE_DEVICE)
  add_compile_definitions(OCCA_THREAD_SHARABLE_ENABLED=0)
else()
  add_compile_definitions(OCCA_THREAD_SHARABLE_ENABLED=1)
  message("-- OCCA sharable by multi-threads : Enabled")
endif()

set(OCCA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(OCCA_BUILD_DIR ${CMAKE_BINARY_DIR})

Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ MAKE_COMPILED_DEFINES := $(shell cat "$(OCCA_DIR)/scripts/build/compiledDefinesT
s,@@OCCA_OPENCL_ENABLED@@,$(OCCA_OPENCL_ENABLED),g;\
s,@@OCCA_METAL_ENABLED@@,$(OCCA_METAL_ENABLED),g;\
s,@@OCCA_DPCPP_ENABLED@@,$(OCCA_DPCPP_ENABLED),g;\
s,@@OCCA_THREAD_SHARABLE_ENABLED@@,$(OCCA_THREAD_SHARABLE_ENABLED),g;\
s,@@OCCA_BUILD_DIR@@,$(OCCA_BUILD_DIR),g;"\
> "$(NEW_COMPILED_DEFINES)")

Expand Down
17 changes: 17 additions & 0 deletions src/core/base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
#include <occa/internal/modes.hpp>
#include <occa/internal/utils/env.hpp>
#include <occa/internal/utils/sys.hpp>
#if OCCA_THREAD_SHARABLE_ENABLED || ENABLE_THREAD_SHARABLE_OCCA
#include <occa/utils/mutex.hpp>
#endif

namespace occa {
//---[ Device Functions ]-------------
device host() {
#if ENABLE_THREAD_SHARABLE_OCCA
static device dev;
#else
thread_local device dev;
#endif
if (!dev.isInitialized()) {
dev = occa::device({
{"mode", "Serial"}
Expand All @@ -21,10 +28,20 @@ namespace occa {
}

// Returns a reference to the current default device, lazily set to host()
// the first time it is requested.
//
// When OCCA_THREAD_SHARABLE_ENABLED is set (the macro defined by the build
// system), a single process-wide device is used and its lazy initialization
// is serialized with a mutex. Otherwise each thread gets its own
// thread_local device.
//
// Note: the mutex only guards the initialization; the returned reference is
// used by callers without holding the lock.
device& getDevice() {
#if OCCA_THREAD_SHARABLE_ENABLED
  static device dev;
  static mutex_t mutex;
  mutex.lock();
  if (!dev.isInitialized()) {
    dev = host();
  }
  mutex.unlock();
#else
  thread_local device dev;
  if (!dev.isInitialized()) {
    dev = host();
  }
#endif
  return dev;
}

Expand Down
13 changes: 13 additions & 0 deletions src/occa/internal/utils/env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,25 @@
#include <occa/internal/utils/env.hpp>
#include <occa/internal/utils/sys.hpp>

#if OCCA_THREAD_SHARABLE_ENABLED
#include <occa/utils/mutex.hpp>
#endif

namespace occa {
// Returns the mutable settings object, filling it from env::baseSettings()
// the first time it is found empty.
//
// With OCCA_THREAD_SHARABLE_ENABLED the object is a process-wide static and
// the lazy fill is done under a mutex; otherwise every thread owns its own
// thread_local copy. The lock covers only the fill, not later use of the
// returned reference.
json& settings() {
#if OCCA_THREAD_SHARABLE_ENABLED
  static json props;
  static mutex_t mutex;
  mutex.lock();
  if (props.size() == 0) {
    props = env::baseSettings();
  }
  mutex.unlock();
#else
  thread_local json props;
  if (props.size() == 0) {
    props = env::baseSettings();
  }
#endif
  return props;
}

Expand Down
13 changes: 13 additions & 0 deletions src/occa/internal/utils/gc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

#include <occa/utils/gc.hpp>

#if OCCA_THREAD_SHARABLE_ENABLED
#include <occa/utils/mutex.hpp>
#endif

namespace occa {
namespace gc {
class withRefs {
Expand All @@ -27,6 +31,11 @@ namespace occa {

template <class entry_t>
class ring_t {
#if OCCA_THREAD_SHARABLE_ENABLED
private:
static mutex_t mutex;
#endif

public:
bool useRefs;
ringEntry_t *head;
Expand All @@ -37,7 +46,11 @@ namespace occa {
void clear();

void addRef(entry_t *entry);
#if OCCA_THREAD_SHARABLE_ENABLED
void removeRef(entry_t *entry, const bool threadLock = true);
#else
void removeRef(entry_t *entry);
#endif

bool needsFree() const;

Expand Down
48 changes: 46 additions & 2 deletions src/occa/internal/utils/gc.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

namespace occa {
namespace gc {
#if OCCA_THREAD_SHARABLE_ENABLED
// Out-of-class definition of the static mutex declared in ring_t.
// Being a static member, there is one mutex per ring_t<entry_t>
// instantiation, shared by every ring of that entry type.
template <class entry_t>
mutex_t ring_t<entry_t>::mutex;
#endif
template <class entry_t>
ring_t<entry_t>::ring_t() :
useRefs(true),
Expand All @@ -20,21 +24,55 @@ namespace occa {

// Links `entry` into this ring, just before the head (i.e. as the new tail).
// A null entry, or an entry that already is the head, is left untouched.
// When OCCA_THREAD_SHARABLE_ENABLED is set the whole operation runs under
// the ring_t static mutex.
template <class entry_t>
void ring_t<entry_t>::addRef(entry_t *entry) {
#if OCCA_THREAD_SHARABLE_ENABLED
  mutex.lock();
#endif
  const bool needsInsert = (entry && (head != entry));
  if (needsInsert) {
    // Unlink the entry from its current ring before re-linking it here
    entry->removeRef();
    if (head) {
      // Splice the entry between the current tail and the head
      ringEntry_t *last = head->leftRingEntry;
      entry->leftRingEntry = last;
      last->rightRingEntry = entry;
      head->leftRingEntry = entry;
      entry->rightRingEntry = head;
    } else {
      // Empty ring: the entry becomes the head
      head = entry;
    }
  }
#if OCCA_THREAD_SHARABLE_ENABLED
  mutex.unlock();
#endif
}

#if OCCA_THREAD_SHARABLE_ENABLED
// Unlinks `entry` from this ring and updates the head if needed.
//
// threadLock: when true, this call acquires and releases the ring_t static
//             mutex itself. When false, the caller is expected to already
//             hold that mutex (as ringMap_t::removeRef does), so the mutex
//             is neither locked nor unlocked here.
//
// The unlock is now conditional on threadLock. Previously it was
// unconditional, so a call with threadLock == false released the caller's
// lock and the caller's own unlock then hit a mutex it no longer held.
template <class entry_t>
void ring_t<entry_t>::removeRef(entry_t *entry, const bool threadLock) {
  if (threadLock) {
    mutex.lock();
  }
  // Nothing to do for a null entry or an empty ring
  if (entry && head) {
    ringEntry_t *tail = head->leftRingEntry;
    // Remove the entry ref from its ring
    entry->removeRef();
    if (head == entry) {
      // Change the head to the tail if entry happened to be the old head
      head = ((tail != entry)
              ? tail
              : NULL);
    }
  }
  if (threadLock) {
    mutex.unlock();
  }
}
#else
template <class entry_t>
void ring_t<entry_t>::removeRef(entry_t *entry) {
// Check if the ring is empty
Expand All @@ -51,7 +89,7 @@ namespace occa {
: NULL);
}
}

#endif
template <class entry_t>
bool ring_t<entry_t>::needsFree() const {
// Object has no more references, safe to free now
Expand Down Expand Up @@ -102,10 +140,13 @@ namespace occa {
if (!entry) {
return;
}
#if OCCA_THREAD_SHARABLE_ENABLED
ring_t<entry_t>::mutex.lock();
#endif
typename entryRingMap_t::iterator it = rings.find(entry);
if (it != rings.end()) {
ring_t<entry_t> &ring = it->second;
ring.removeRef(entry);
ring.removeRef(entry, false);
rings.erase(it);
// Change key if head changed
if (ring.head &&
Expand All @@ -116,6 +157,9 @@ namespace occa {
} else {
entry->removeRef();
}
#if OCCA_THREAD_SHARABLE_ENABLED
ring_t<entry_t>::mutex.unlock();
#endif
}

template <class entry_t>
Expand Down
35 changes: 34 additions & 1 deletion tests/src/internal/utils/gc.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
#if ENABLE_THREAD_SHARABLE_OCCA
#include <thread>
#endif
#include <occa/defines.hpp>
#include <occa/internal/utils/gc.hpp>
#include <occa/internal/utils/testing.hpp>

void testWithRefs();
void testRingEntry();
void testRing();
#if ENABLE_THREAD_SHARABLE_OCCA
void testRingMultiThread();
#endif

// Test driver: runs the gc utility tests one after another and returns 0
// once they have all been called.
int main(const int argc, const char **argv) {
testWithRefs();
testRingEntry();
testRing();

// NOTE(review): the build option defines OCCA_THREAD_SHARABLE_ENABLED, not
// ENABLE_THREAD_SHARABLE_OCCA — confirm this guard is ever true, otherwise
// the multi-thread test below is never compiled or run.
#if ENABLE_THREAD_SHARABLE_OCCA
testRingMultiThread();
#endif
return 0;
}

Expand Down Expand Up @@ -109,3 +117,28 @@ void testRing() {
(void*) NULL);
ASSERT_TRUE(values.needsFree());
}

#if ENABLE_THREAD_SHARABLE_OCCA
void testRingMultiThread() {
occa::gc::ring_t<occa::gc::ringEntry_t> values;
const int nEntry = 1000;

auto f = [&]() {
occa::gc::ringEntry_t e[nEntry];
for (auto i = 0; i < nEntry; i++) {
values.addRef(e+i);
values.removeRef(e+i);
}
};

std::thread th1(f);
std::thread th2(f);

th1.join();
th2.join();

ASSERT_EQ((void*) values.head,
(void*) NULL);
ASSERT_TRUE(values.needsFree());
}
#endif

0 comments on commit 6b1c429

Please sign in to comment.