-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimplicit_als_train.h
57 lines (44 loc) · 1.25 KB
/
implicit_als_train.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/***************************************************
*
* file: implicit_als_train.h
*
* Copyright (C) Angela Burova and Egor Smirnov 2017
*
****************************************************
*/
#ifndef _IMPLICIT_ALS_TRAIN_
#define _IMPLICIT_ALS_TRAIN_
#include "implicit_als_model.h"
namespace als
{
/**
 * Training hyperparameters consumed by Train<FPType>.
 *
 * Plain aggregate: set the fields you need before calling Train::compute().
 */
struct Parameter
{
    // Regularization coefficient (presumably the weight applied by
    // Train::getRegularization — confirm in the implementation).
    // Written as a double literal on purpose: the former `0.001f` is a float
    // literal and widens to 0.0010000000474974513, not 0.001.
    double lambda = 0.001;
    // Number of training iterations (presumably full alternating sweeps).
    size_t nIteration = 20;
    // Number of factors per user/item (presumably the latent dimension).
    size_t nFactors = 10;
};
/**
 * Trainer for the implicit-feedback ALS model.
 *
 * Usage: adjust `parameter`, call compute() with the input table, then fetch
 * the result with getModel(). All training logic is defined out of line.
 *
 * @tparam FPType floating-point type used for tables and the model.
 */
template<typename FPType>
class Train
{
public:
    /// Creates a trainer with no model; all dimension counters start at 0.
    Train(): _model(nullptr) {}

    /// Training entry point for `data` (defined out of line).
    void compute(TablePtr<FPType> data);

    /// Returns the stored model handle; nullptr until one has been assigned
    /// (presumably by compute()/initModel()).
    ModelPtr<FPType> getModel() const
    {
        return _model;
    }

    /// Hyperparameters; presumably read by the out-of-line training code.
    Parameter parameter;

protected:
    ModelPtr<FPType> _model;
    // Default member initializers: the constructor previously left these
    // uninitialized, so any read before training assigned them was undefined.
    size_t _nFactors = 0;
    size_t _nUsers = 0;
    size_t _nItems = 0;

    // Out-of-line helpers. The descriptions below are inferred from names and
    // must be confirmed against the implementation file.
    // Presumably runs the training iterations on dataPtr.
    void computeInternal(TablePtr<FPType>& dataPtr);
    // Presumably allocates _model and sets the dimension counters from data.
    void initModel(TablePtr<FPType>& data);
    // Presumably fills the initial item-factor table.
    void initItemsFactors(TablePtr<FPType>& itemsFactors, TablePtr<FPType>& dataPtr);
    // Presumably recomputes currentFactors while otherFactors is held fixed.
    void updateFactors(TablePtr<FPType> otherFactors, TablePtr<FPType> currentFactors, TablePtr<FPType>& dataPtr);
    // Presumably builds the regularization term for a row with nonZero entries.
    TablePtr<FPType> getRegularization(const size_t nonZero);
    // Presumably extracts the sub-matrices of dataPtr/factors relevant to row idx.
    std::pair<TablePtr<FPType>, TablePtr<FPType>> getSubMatrixes(TablePtr<FPType>& dataPtr, TablePtr<FPType> factors, const size_t idx);
};
} // als
#endif // _IMPLICIT_ALS_TRAIN_