Skip to content

Commit 5dbea42

Browse files
committed
Modified KF to use functions from +plugins/+ClassKF
1 parent 40ce1aa commit 5dbea42

File tree

2 files changed

+190
-6
lines changed

2 files changed

+190
-6
lines changed

+inverse/@KalmanInverter/invert.asv

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
%% Copyright © 2025- Joonas Lahtinen
2+
function [z_vec, self] = invert(self, f, L, procFile, source_direction_mode, source_positions, opts)
3+
4+
%
5+
% invert
6+
%
7+
% Builds a reconstruction of source dipoles from a given lead field with
8+
% the Kalman filtering method.
9+
%
10+
% Inputs:
11+
%
12+
% - self
13+
%
14+
% An instance of KalmanInverter with the method-specific parameters.
15+
%
16+
% - f
17+
%
18+
% Some vector.
19+
%
20+
% - L
21+
%
22+
% The lead field that is being inverted.
23+
%
24+
% - procFile
25+
%
26+
% A struct with source space indices.
27+
%
28+
% - source_direction_mode
29+
%
30+
% The way the orientations of the sources should be interpreted.
31+
%
32+
% - opts.use_gpu = false
33+
%
34+
% A logical flag for choosing whether a GPU will be used in
35+
% computations, if available.
36+
%
37+
% Outputs:
38+
%
39+
% - reconstruction
40+
%
41+
% The reconstrution of the dipoles.
42+
%
43+
44+
arguments
45+
46+
self (1,1) inverse.KalmanInverter
47+
48+
f (:,1) {mustBeA(f,["double","gpuArray"])}
49+
50+
L (:,:) {mustBeA(L,["double","gpuArray"])}
51+
52+
procFile (1,1) struct
53+
54+
source_direction_mode
55+
56+
source_positions
57+
58+
opts.use_gpu (1,1) logical = false
59+
60+
opts.normalize_data (1,1) double = 1
61+
62+
end
63+
64+
65+
% Get needed parameters from self and others.
66+
67+
snr_val = self.signal_to_noise_ratio;
68+
std_lhood = 10^(-self.signal_to_noise_ratio/20);
69+
pm_val = self.inv_prior_over_measurement_db;
70+
amplitude_db = self.inv_amplitude_db;
71+
pm_val = pm_val - amplitude_db;
72+
73+
% Then start inverting.
74+
%% CALCULATION STARTS HERE
75+
% m_0 = prior mean
76+
m = zeros(size(L,2), 1);
77+
78+
%Prior covariance is saved in self.prev_step_posterior_cov
79+
if isempty(self.prev_step_posterior_cov)
80+
if max(size(theta0)) == 1
81+
self.prev_step_posterior_cov = eye(size(L,2)) * theta0;
82+
else
83+
self.prev_step_posterior_cov = diag(theta0);
84+
end
85+
end
86+
87+
88+
if isempty(self.prev_step_reconstruction)
89+
if not(strcmp(self.method_type,"Ensembled Kalman filter"))
90+
self.prev_step_reconstruction = zeros(size(L,2),1);
91+
else
92+
self.prev_step_reconstruction = mvnrnd(zeros(size(L,2),1), self.prev_step_posterior_cov, self.number_of_ensembles)';
93+
end
94+
end
95+
96+
if not(isempty(self.evolution_var))
97+
self.evolution_cov = diag(self.evolution_var(:,1));
98+
self.evolution_var(:,1) = [];
99+
end
100+
101+
if opts.use_gpu && gpuDeviceCount > 0
102+
self.evolution_cov = gpuArray(self.evolution_cov);
103+
self.noise_cov = gpuArray(self.noise_cov);
104+
self.prev_step_posterior_cov = gpuArray(self.prev_step_posterior_cov);
105+
end
106+
%% KALMAN FILTER
107+
if strcmp(self.method_type,"Basic Kalman filter")
108+
% Prediction
109+
[x, P] = plugins.ClassKF.class_kf_predict(self);
110+
% Update
111+
[x, P] = plugins.ClassKF.kf_update(x, P, f, L, self.noise_cov);
112+
if self.use_smoothing
113+
self.posterior_covs = [self.posterior_covs,gather(P)];
114+
end
115+
z_vec = gather(x);
116+
self.prev_step_reconstruction = x;
117+
self.prev_step_posterior_cov = P;
118+
elseif strcmp(self.method_type,"Standardized Kalman filter")
119+
% Prediction
120+
[x, P] = plugins.ClassKF.class_kf_predict(self);
121+
% Update
122+
[x, P, ~, D] = plugins.ClassKF.kf_sL_update(x, gather(P), f, L, self.noise_cov);
123+
if self.use_smoothing
124+
self.posterior_covs = [self.posterior_covs,gather(P)];
125+
end
126+
self.prev_step_reconstruction = x;
127+
z_vec = gather(D*self.prev_step_reconstruction);
128+
self.prev_step_posterior_cov = P;
129+
elseif strcmp(self.method_type,"Approximated Standardized Kalman filter")
130+
% Prediction
131+
[x, P] = plugins.ClassKF.class_kf_predict(self);
132+
% Update
133+
[x, P, ~, D] = plugins.ClassKF.kf_sL_update_approx(x, P, f, L, self.noise_cov);
134+
if self.use_smoothing
135+
self.posterior_covs = [self.posterior_covs,gather(P)];
136+
end
137+
self.prev_step_reconstruction = x;
138+
z_vec = gather(D*self.prev_step_reconstruction);
139+
self.prev_step_posterior_cov = P;
140+
elseif strcmp(self.method_type,"Ensembled Kalman filter")
141+
w = mvnrnd(zeros(size(L,2),1), self.evolution_cov, self.number_of_ensembles)';
142+
% Forecasts
143+
x_f = self.state_transition_model_A * self.prev_step_reconstruction + w;
144+
C = cov(x_f');
145+
correlationLocalization = true;
146+
if correlationLocalization
147+
T = corrcoef(x_f');
148+
% explain How to find 0.05
149+
T(abs(T) < 0.05) = 0;
150+
C = C .* T;
151+
end
152+
v = mvnrnd(zeros(size(self.noise_cov,1),1), self.noise_cov, self.number_of_ensembles);
153+
154+
% method to calculate resolution D
155+
method = '3';
156+
if(method == '1')
157+
P_sqrtm = sqrtm(C);
158+
B = L * P_sqrtm;
159+
G = B' / (B * B' + self.noise_cov);
160+
w_t = 1 ./ sum(G.' .* B, 1)';
161+
D = w_t .* inv(P_sqrtm);
162+
elseif(method == '2')
163+
% complexity O(n^3)
164+
[Ur,Sr,Vr] = svd(C);
165+
Sr = diag(Sr);
166+
RNK = sum(Sr > (length(Sr) * eps(single(Sr(1)))));
167+
SIR = Vr(:,1:RNK) * diag(1./sqrt(Sr(1:RNK))) * Ur(:,1:RNK)'; % square root
168+
P_sqrtm = Vr(:,1:RNK) * diag(sqrt(Sr(1:RNK))) * Ur(:,1:RNK)';
169+
B = L * P_sqrtm;
170+
G = B' / (B * B' + self.noise_cov);
171+
w_t = 1 ./ sum(G.' .* B, 1)';
172+
D = w_t .* SIR;
173+
else
174+
D = speye(size(C));
175+
end
176+
% Update
177+
K = C * L' / (L * C * L' + self.noise_cov);
178+
self.prev_step_reconstruction = x_f + K *(f + v' - L*x_f);
179+
% x_ensemble = x_ensemble';
180+
mean_x = mean(self.prev_step_reconstruction,2);
181+
z_vec = D*mean_x;
182+
end
183+
184+
end % function

+inverse/@KalmanInverter/invert.m

+6-6
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@
106106
%% KALMAN FILTER
107107
if strcmp(self.method_type,"Basic Kalman filter")
108108
% Prediction
109-
[x, P] = class_kf_predict(self);
109+
[x, P] = plugins.ClassKF.class_kf_predict(self);
110110
% Update
111-
[x, P] = kf_update(x, P, f, L, self.noise_cov);
111+
[x, P] = plugins.ClassKF.kf_update(x, P, f, L, self.noise_cov);
112112
if self.use_smoothing
113113
self.posterior_covs = [self.posterior_covs,gather(P)];
114114
end
@@ -117,9 +117,9 @@
117117
self.prev_step_posterior_cov = P;
118118
elseif strcmp(self.method_type,"Standardized Kalman filter")
119119
% Prediction
120-
[x, P] = class_kf_predict(self);
120+
[x, P] = plugins.ClassKF.class_kf_predict(self);
121121
% Update
122-
[x, P, ~, D] = kf_sL_update(x, gather(P), f, L, self.noise_cov);
122+
[x, P, ~, D] = plugins.ClassKF.kf_sL_update(x, gather(P), f, L, self.noise_cov);
123123
if self.use_smoothing
124124
self.posterior_covs = [self.posterior_covs,gather(P)];
125125
end
@@ -128,9 +128,9 @@
128128
self.prev_step_posterior_cov = P;
129129
elseif strcmp(self.method_type,"Approximated Standardized Kalman filter")
130130
% Prediction
131-
[x, P] = class_kf_predict(self);
131+
[x, P] = plugins.ClassKF.class_kf_predict(self);
132132
% Update
133-
[x, P, ~, D] = kf_sL_update_approx(x, P, f, L, self.noise_cov);
133+
[x, P, ~, D] = plugins.ClassKF.kf_sL_update_approx(x, P, f, L, self.noise_cov);
134134
if self.use_smoothing
135135
self.posterior_covs = [self.posterior_covs,gather(P)];
136136
end

0 commit comments

Comments
 (0)