Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple precisions in CeedVector #948

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
43 changes: 43 additions & 0 deletions backends/hip-ref/ceed-hip-ref-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,47 @@ static inline int CeedScalarTypeGetSize_Hip(Ceed ceed, CeedScalarType prec,
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Get info about the current status of the different precisions in the
// valid, borrowed, and owned arrays, for a specific mem_type
//------------------------------------------------------------------------------
static int CeedVectorCheckArrayStatus_Hip(CeedVector vec,
CeedMemType mem_type,
unsigned int *valid_status,
unsigned int *borrowed_status,
unsigned int *owned_status) {

int ierr;
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
*valid_status = 0;
*borrowed_status = 0;
*owned_status = 0;
switch(mem_type) {
case CEED_MEM_HOST:
for (int i = 0; i < CEED_NUM_PRECISIONS; i++) {
if (!!impl->h_array.values[i])
*valid_status += 1 << i;
if (!!impl->h_array_borrowed.values[i])
*borrowed_status += 1 << i;
if (!!impl->h_array_owned.values[i])
*owned_status += 1 << i;
}
break;
case CEED_MEM_DEVICE:
for (int i = 0; i < CEED_NUM_PRECISIONS; i++) {
if (!!impl->d_array.values[i])
*valid_status += 1 << i;
if (!!impl->d_array_borrowed.values[i])
*borrowed_status += 1 << i;
if (!!impl->d_array_owned.values[i])
*owned_status += 1 << i;
}
break;
}
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Set all pointers as invalid
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -1171,6 +1212,8 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType",
CeedVectorHasBorrowedArrayOfType_Hip);
CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "CheckArrayStatus",
CeedVectorCheckArrayStatus_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArrayGeneric",
CeedVectorSetArrayGeneric_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArrayGeneric",
Expand Down
2 changes: 2 additions & 0 deletions include/ceed-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ struct CeedVector_private {
Ceed ceed;
int (*HasValidArray)(CeedVector, bool *);
int (*HasBorrowedArrayOfType)(CeedVector, CeedMemType, CeedScalarType, bool *);
int (*CheckArrayStatus)(CeedVector, CeedMemType, unsigned int *, unsigned int *,
unsigned int *);
int (*SetArrayGeneric)(CeedVector, CeedMemType, CeedScalarType, CeedCopyMode,
void *);
int (*SetValue)(CeedVector, CeedScalar);
Expand Down
3 changes: 3 additions & 0 deletions include/ceed/ceed.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ CEED_EXTERN int CeedVectorRestoreArrayReadFP32(CeedVector vec,
const float **array);
CEED_EXTERN int CeedVectorRestoreArrayReadFP64(CeedVector vec,
const double **array);
CEED_EXTERN int CeedVectorCheckArrayStatus(CeedVector vec, CeedMemType mem_type,
unsigned int *valid_status, unsigned int *borrowed_status,
unsigned int *owned_status);
CEED_EXTERN int CeedVectorNorm(CeedVector vec, CeedNormType type,
CeedScalar *norm);
CEED_EXTERN int CeedVectorScale(CeedVector x, CeedScalar alpha);
Expand Down
33 changes: 33 additions & 0 deletions interface/ceed-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,39 @@ int CeedVectorRestoreArrayReadFP64(CeedVector vec, const double **array) {
return CeedVectorRestoreArrayReadGeneric(vec, (const void **) array);
}

/**
@brief Check the current status of the CeedVector's data arrays.

This function sets unsigned int parameters indicating for which precisions
the data for mem_type is not NULL, for borrowed and owned arrays, and also
which precisions are currently valid.
The values can be checked bitwise, such that, e.g.,
(valid_status & (1 << CEED_SCALAR_FP32)) will be true if the FP32 array is
currently valid, and false otherwise.

@param vec CeedVector for which to check status
@param mem_type Mem type for which to check status
@param valid_status Status indicator for valid data
@param borrowed_status Status indicator for borrowed arrays
@param owned_status Status indicator for owned arrays
**/
int CeedVectorCheckArrayStatus(CeedVector vec, CeedMemType mem_type,
Comment on lines +1074 to +1090
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this function to the interface level, thinking it would be used by tests for multiprecision vector functionality (and potentially, users may want this feature).
It also has a backend implementation, so it's added to the CeedVector_private struct. Since this affects the user interface, I wanted to make sure and point it out to see what people thought.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we keep this, we may also want an interface function for convenient "pretty" printing of status for all vector arrays, or we could implement such a function in the tests directly.

unsigned int *valid_status,
unsigned int *borrowed_status,
unsigned int *owned_status) {
int ierr;

if (!vec->CheckArrayStatus)
// LCOV_EXCL_START
return CeedError(vec->ceed, CEED_ERROR_UNSUPPORTED,
"Backend does not support CheckArrayStatus");
// LCOV_EXCL_STOP

ierr = vec->CheckArrayStatus(vec, mem_type, valid_status, borrowed_status,
owned_status); CeedChk(ierr);
return CEED_ERROR_SUCCESS;
}

/**
@brief Get the norm of a CeedVector.

Expand Down
1 change: 1 addition & 0 deletions interface/ceed.c
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ int CeedInit(const char *resource, Ceed *ceed) {
CEED_FTABLE_ENTRY(CeedVector, GetArrayWriteGeneric),
CEED_FTABLE_ENTRY(CeedVector, RestoreArrayGeneric),
CEED_FTABLE_ENTRY(CeedVector, RestoreArrayReadGeneric),
CEED_FTABLE_ENTRY(CeedVector, CheckArrayStatus),
CEED_FTABLE_ENTRY(CeedVector, Norm),
CEED_FTABLE_ENTRY(CeedVector, Scale),
CEED_FTABLE_ENTRY(CeedVector, AXPY),
Expand Down