13
13
14
14
#include " common.hpp"
15
15
#include " context.hpp"
16
+ #include " usm.hpp"
16
17
#include < cstdlib>
17
18
18
19
namespace umf {
@@ -27,7 +28,8 @@ static ur_result_t alloc_helper(ur_context_handle_t hContext,
27
28
auto alignment = (pUSMDesc && pUSMDesc->align ) ? pUSMDesc->align : 1u ;
28
29
UR_ASSERT (isPowerOf2 (alignment), UR_RESULT_ERROR_UNSUPPORTED_ALIGNMENT);
29
30
UR_ASSERT (ppMem, UR_RESULT_ERROR_INVALID_NULL_POINTER);
30
- // TODO: Check Max size when UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE is implemented
31
+ // TODO: Check Max size for host allocations when
32
+ // UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE is implemented
31
33
UR_ASSERT (size > 0 , UR_RESULT_ERROR_INVALID_USM_SIZE);
32
34
33
35
auto *ptr = hContext->add_alloc (alignment, type, size, nullptr );
@@ -49,8 +51,9 @@ UR_APIEXPORT ur_result_t UR_APICALL
49
51
urUSMDeviceAlloc (ur_context_handle_t hContext, ur_device_handle_t hDevice,
50
52
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t pool,
51
53
size_t size, void **ppMem) {
52
- std::ignore = hDevice;
53
54
std::ignore = pool;
55
+ UR_ASSERT (size < native_cpu::detail::maxUSMAllocationSize (hDevice),
56
+ UR_RESULT_ERROR_INVALID_USM_SIZE);
54
57
55
58
return alloc_helper (hContext, pUSMDesc, size, ppMem, UR_USM_TYPE_DEVICE);
56
59
}
@@ -59,8 +62,9 @@ UR_APIEXPORT ur_result_t UR_APICALL
59
62
urUSMSharedAlloc (ur_context_handle_t hContext, ur_device_handle_t hDevice,
60
63
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t pool,
61
64
size_t size, void **ppMem) {
62
- std::ignore = hDevice;
63
65
std::ignore = pool;
66
+ UR_ASSERT (size < native_cpu::detail::maxUSMAllocationSize (hDevice),
67
+ UR_RESULT_ERROR_INVALID_USM_SIZE);
64
68
65
69
return alloc_helper (hContext, pUSMDesc, size, ppMem, UR_USM_TYPE_SHARED);
66
70
}
@@ -155,3 +159,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMReleaseExp(ur_context_handle_t Context,
155
159
std::ignore = HostPtr;
156
160
DIE_NO_IMPLEMENTATION;
157
161
}
162
+
163
+ uint64_t maxUSMAllocationSize (const ur_device_handle_t &Device) {
164
+ size_t Global = Device->mem_size ;
165
+
166
+ auto QuarterGlobal = static_cast <uint32_t >(Global / 4u );
167
+
168
+ return std::max (std::min (1024u * 1024u * 1024u , QuarterGlobal),
169
+ 32u * 1024u * 1024u );
170
+ }
0 commit comments