Skip to content

Commit 93661a1

Browse files
author
dmitry.razdoburdin
committed
Fixing the UB bug
1 parent e18e342 commit 93661a1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

plugin/updater_oneapi/regression_obj_oneapi.cc

+3-2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class RegLossObjOneAPI : public ObjFunction {
7171
sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
7272
is_null_weight ? 1 : info.weights_.Size());
7373

74+
const size_t n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));
75+
7476
sycl::buffer<int, 1> additional_input_buf(1);
7577
{
7678
auto additional_input_acc = additional_input_buf.get_access<sycl::access::mode::write>();
@@ -92,7 +94,7 @@ class RegLossObjOneAPI : public ObjFunction {
9294
cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) {
9395
int idx = pid[0];
9496
bst_float p = Loss::PredTransform(preds_acc[idx]);
95-
bst_float w = is_null_weight ? 1.0f : weights_acc[idx];
97+
bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
9698
bst_float label = labels_acc[idx];
9799
if (label == 1.0f) {
98100
w *= scale_pos_weight;
@@ -125,7 +127,6 @@ class RegLossObjOneAPI : public ObjFunction {
125127

126128
void PredTransform(HostDeviceVector<float> *io_preds) const override {
127129
size_t const ndata = io_preds->Size();
128-
129130
sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
130131

131132
qu_.submit([&](sycl::handler& cgh) {

0 commit comments

Comments
 (0)