File tree 1 file changed +3
-2
lines changed
1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,8 @@ class RegLossObjOneAPI : public ObjFunction {
71
71
sycl::buffer<bst_float, 1 > weights_buf (is_null_weight ? NULL : info.weights_ .HostPointer (),
72
72
is_null_weight ? 1 : info.weights_ .Size ());
73
73
74
+ const size_t n_targets = std::max (info.labels .Shape (1 ), static_cast <size_t >(1 ));
75
+
74
76
sycl::buffer<int , 1 > additional_input_buf (1 );
75
77
{
76
78
auto additional_input_acc = additional_input_buf.get_access <sycl::access ::mode::write >();
@@ -92,7 +94,7 @@ class RegLossObjOneAPI : public ObjFunction {
92
94
cgh.parallel_for <>(sycl::range<1 >(ndata), [=](sycl::id<1 > pid) {
93
95
int idx = pid[0 ];
94
96
bst_float p = Loss::PredTransform (preds_acc[idx]);
95
- bst_float w = is_null_weight ? 1 .0f : weights_acc[idx];
97
+ bst_float w = is_null_weight ? 1 .0f : weights_acc[idx/n_targets ];
96
98
bst_float label = labels_acc[idx];
97
99
if (label == 1 .0f ) {
98
100
w *= scale_pos_weight;
@@ -125,7 +127,6 @@ class RegLossObjOneAPI : public ObjFunction {
125
127
126
128
void PredTransform (HostDeviceVector<float > *io_preds) const override {
127
129
size_t const ndata = io_preds->Size ();
128
-
129
130
sycl::buffer<bst_float, 1 > io_preds_buf (io_preds->HostPointer (), io_preds->Size ());
130
131
131
132
qu_.submit ([&](sycl::handler& cgh) {
You can’t perform that action at this time.
0 commit comments