Fix bug in handmade_aligned_realloc

This commit is contained in:
Charles Schlosser 2022-11-18 22:35:31 +00:00 committed by Antonio Sánchez
parent 6728683938
commit 044f3f6234

View File

@ -101,12 +101,12 @@ inline void throw_std_bad_alloc()
*/
EIGEN_DEVICE_FUNC inline void* handmade_aligned_malloc(std::size_t size, std::size_t alignment = EIGEN_DEFAULT_ALIGN_BYTES)
{
eigen_assert(alignment >= sizeof(void*) && (alignment & (alignment-1)) == 0 && "Alignment must be at least sizeof(void*) and a power of 2");
eigen_assert(alignment >= sizeof(void*) && alignment <= 128 && (alignment & (alignment-1)) == 0 && "Alignment must be at least sizeof(void*), less than or equal to 128, and a power of 2");
void* original = std::malloc(size + alignment);
if (original == 0) return 0;
uint8_t offset = alignment - (reinterpret_cast<std::size_t>(original) & (alignment - 1));
void* aligned = static_cast<char*>(original) + offset;
*(static_cast<char*>(aligned) - 1) = offset;
uint8_t offset = static_cast<uint8_t>(alignment - (reinterpret_cast<std::size_t>(original) & (alignment - 1)));
void* aligned = static_cast<void*>(static_cast<uint8_t*>(original) + offset);
*(static_cast<uint8_t*>(aligned) - 1) = offset;
return aligned;
}
@ -114,8 +114,8 @@ EIGEN_DEVICE_FUNC inline void* handmade_aligned_malloc(std::size_t size, std::si
EIGEN_DEVICE_FUNC inline void handmade_aligned_free(void *ptr)
{
if (ptr) {
uint8_t offset = *(static_cast<char*>(ptr) - 1);
void* original = static_cast<char*>(ptr) - offset;
uint8_t offset = static_cast<uint8_t>(*(static_cast<uint8_t*>(ptr) - 1));
void* original = static_cast<void*>(static_cast<uint8_t*>(ptr) - offset);
std::free(original);
}
}
@ -125,21 +125,23 @@ EIGEN_DEVICE_FUNC inline void handmade_aligned_free(void *ptr)
* Since we know that our handmade version is based on std::malloc
* we can use std::realloc to implement efficient reallocation.
*/
EIGEN_DEVICE_FUNC inline void* handmade_aligned_realloc(void* ptr, std::size_t size, std::size_t alignment = EIGEN_DEFAULT_ALIGN_BYTES)
EIGEN_DEVICE_FUNC inline void* handmade_aligned_realloc(void* ptr, std::size_t new_size, std::size_t old_size, std::size_t alignment = EIGEN_DEFAULT_ALIGN_BYTES)
{
if (ptr == 0) return handmade_aligned_malloc(size, alignment);
uint8_t previous_offset = *(static_cast<char*>(ptr) - 1);
void* previous_original = static_cast<char*>(ptr) - previous_offset;
void* original = std::realloc(previous_original, size + alignment);
if (ptr == 0) return handmade_aligned_malloc(new_size, alignment);
uint8_t old_offset = *(static_cast<uint8_t*>(ptr) - 1);
void* old_original = static_cast<uint8_t*>(ptr) - old_offset;
void* original = std::realloc(old_original, new_size + alignment);
if (original == 0) return 0;
if (original != previous_original) {
uint8_t offset = alignment - (reinterpret_cast<std::size_t>(original) & (alignment - 1));
void* aligned = static_cast<char*>(original) + offset;
std::memmove(aligned, ptr, size);
*(static_cast<char*>(aligned) - 1) = offset;
return aligned;
if (original == old_original) return ptr;
uint8_t offset = static_cast<uint8_t>(alignment - (reinterpret_cast<std::size_t>(original) & (alignment - 1)));
void* aligned = static_cast<void*>(static_cast<uint8_t*>(original) + offset);
if (offset != old_offset) {
const void* src = static_cast<const void*>(static_cast<uint8_t*>(original) + old_offset);
std::size_t count = (std::min)(new_size, old_size);
std::memmove(aligned, src, count);
}
return ptr;
*(static_cast<uint8_t*>(aligned) - 1) = offset;
return aligned;
}
/*****************************************************************************
@ -217,13 +219,12 @@ EIGEN_DEVICE_FUNC inline void aligned_free(void *ptr)
EIGEN_DEVICE_FUNC inline void* aligned_realloc(void *ptr, std::size_t new_size, std::size_t old_size)
{
if (ptr == 0) return aligned_malloc(new_size);
EIGEN_UNUSED_VARIABLE(old_size)
void *result;
#if (EIGEN_DEFAULT_ALIGN_BYTES==0) || EIGEN_MALLOC_ALREADY_ALIGNED
EIGEN_UNUSED_VARIABLE(old_size)
result = std::realloc(ptr,new_size);
#else
result = handmade_aligned_realloc(ptr,new_size);
result = handmade_aligned_realloc(ptr,new_size,old_size);
#endif
if (!result && new_size)