add thrust
This commit is contained in:
@@ -31,3 +31,8 @@ set(LIBRARIES
|
||||
${ROOT_LIBRARIES}
|
||||
)
|
||||
uLib_add_tests(Core)
|
||||
|
||||
if(USE_CUDA)
|
||||
set_source_files_properties(VectorMetaAllocatorTest.cpp PROPERTIES LANGUAGE CUDA)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -12,6 +12,15 @@
|
||||
#include "testing-prototype.h"
|
||||
#include <Core/Vector.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
struct DoubleFunctor {
|
||||
__host__ __device__ int operator()(int x) const { return x * 2; }
|
||||
};
|
||||
#endif
|
||||
|
||||
int main() {
|
||||
BEGIN_TESTING(VectorMetaAllocator);
|
||||
|
||||
@@ -41,14 +50,31 @@ int main() {
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Verify DeviceData() matches GetVRAMData()
|
||||
{
|
||||
thrust::device_ptr<int> dev_ptr = v.DeviceData();
|
||||
if (dev_ptr.get() != vram_ptr) {
|
||||
std::cout << "Error: DeviceData() does not match GetVRAMData()!\n";
|
||||
exit(1);
|
||||
}
|
||||
std::cout << "DeviceData() matches GetVRAMData(). OK\n";
|
||||
}
|
||||
|
||||
// Use thrust::transform via DeviceBegin()/DeviceEnd() to double all elements
|
||||
// on device
|
||||
std::cout << "Doubling elements on device via thrust::transform...\n";
|
||||
thrust::transform(v.DeviceBegin(), v.DeviceEnd(), v.DeviceBegin(),
|
||||
DoubleFunctor{});
|
||||
|
||||
std::cout << "Moving back to RAM...\n";
|
||||
v.MoveToRAM();
|
||||
|
||||
std::cout << "RAM contents after VRAM trip: ";
|
||||
std::cout << "RAM contents after VRAM trip + thrust transform: ";
|
||||
for (size_t i = 0; i < v.size(); ++i) {
|
||||
std::cout << v[i] << " ";
|
||||
if (v[i] != (int)(i + 1)) {
|
||||
std::cout << "\nError: Data corrupted after RAM->VRAM->RAM trip at index "
|
||||
if (v[i] != (int)((i + 1) * 2)) {
|
||||
std::cout << "\nError: Data corrupted after RAM->VRAM->thrust->RAM trip "
|
||||
"at index "
|
||||
<< i << "\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user