
Commit 17f9527

Add files via upload
Updated the way forgetting rates are implemented. The forgetting rate parameter (omega) now takes values between 0 and 1, where higher values lead to greater forgetting. The updated implementation also only applies forgetting to concentration parameter values added to the initial prior values. This prevents the concentration parameters from moving to values that are implausibly low, which can cause numerical problems when running the code.
1 parent 2e7bd86 commit 17f9527

3 files changed: +58 -41 lines
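For reference, the revised update described in the commit message takes the following form. This is a minimal illustrative sketch: only the update equation mirrors the diffs below, while the toy matrices and parameter values are invented for illustration and are not part of the committed code.

% Sketch of the revised forgetting update; variable names mirror the diffs below,
% but the specific values are toy examples, not part of the committed code.
a_0        = [4 1; 1 4];   % initial prior concentration parameters (floor values)
a          = [9 2; 2 9];   % current concentration parameters after some learning
omega      = 0.25;         % forgetting rate: 0 = no forgetting, values near 1 = strong forgetting
eta        = 1;            % learning rate
a_learning = [1 0; 0 1];   % evidence accumulated on the current trial

% Only the counts accumulated above the initial prior decay toward the floor,
% so a can never fall below a_0 (avoiding implausibly low concentration parameters).
a = (a - a_0)*(1 - omega) + a_0 + a_learning*eta;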

Simplified_simulation_script.m

+15 -15
@@ -6,7 +6,7 @@
 % Application to Empirical Data
 
 % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte
-
+% UPDATED: 8/28/2024 (modified forgetting rate implementation)
 rng('shuffle')
 close all
 clear
@@ -36,11 +36,11 @@
 % function starting on line 810. This includes, among
 % others (similar to in the main tutorial script):
 
-% prior beliefs about context (d): alter line 866
+% prior beliefs about context (d): alter line 876
 
-% beliefs about hint accuracy in the likelihood (a): alter lines 986-988
+% beliefs about hint accuracy in the likelihood (a): alter lines 996-998
 
-% to adjust habits (e), alter line 1145
+% to adjust habits (e), alter line 1155
 
 %% Specify Generative Model
 
@@ -336,7 +336,7 @@
 for modality = 1:NumModalities
 % prior preferences about outcomes
 predictive_observations_posterior = cell_md_dot(a{modality},Expected_states(:)); %posterior over observations
-Gintermediate(policy) = Gintermediate(policy) + predictive_observations_posterior'*(C{modality}(:,timestep));
+Gintermediate(policy) = Gintermediate(policy) + predictive_observations_posterior'*(C{modality}(:,t));
 
 % Bayesian surprise about parameters
 if isfield(MDP,'a')
@@ -445,7 +445,7 @@
 a_learning = spm_cross(a_learning,BMA_states{factor}(:,t));
 end
 a_learning = a_learning.*(MDP.a{modality} > 0);
-MDP.a{modality} = MDP.a{modality}*omega + a_learning*eta;
+MDP.a{modality} = (MDP.a{modality}-MDP.a_0{modality})*(1-omega) + MDP.a_0{modality} + a_learning*eta;
 end
 end
 end
@@ -454,13 +454,13 @@
 if isfield(MDP,'d')
 for factor = 1:NumFactors
 i = MDP.d{factor} > 0;
-MDP.d{factor}(i) = omega*MDP.d{factor}(i) + eta*BMA_states{factor}(i,1);
+MDP.d{factor}(i) = (1-omega)*(MDP.d{factor}(i)-MDP.d_0{factor}(i)) + MDP.d_0{factor}(i) + eta*BMA_states{factor}(i,1);
 end
 end
 
 % policies e (habits)
 if isfield(MDP,'e')
-MDP.e = omega*MDP.e + eta*policy_posterior(:,T);
+MDP.e = (1-omega)*(MDP.e - MDP.e_0) + MDP.e_0 + eta*policy_posterior(:,T);
 end
 
 % Free energy of concentration parameters
@@ -1163,9 +1163,9 @@
 eta = 1; % Default (maximum) learning rate
 
 % Omega: forgetting rate (0-1) controlling the magnitude of reduction in concentration
-% parameter values after each trial (if learning is enabled).
+% parameter values after each trial (if learning is enabled). NOTE THE FORM OF FORGETTING IMPLEMENTED HERE IS MODIFIED FROM THE DESCRIPTION IN THE PUBLISHED TUTORIAL FOR IMPROVED PERFORMANCE.
 
-omega = 1; % Default value indicating there is no forgetting (values < 1 indicate forgetting)
+omega = 0; % Default value indicating there is no forgetting (values approaching 1 indicate forgetting)
 
 % Beta: Expected precision of expected free energy (G) over policies (a
 % positive value, with higher values indicating lower expected precision).
@@ -1194,13 +1194,13 @@
 mdp.B = B; % transition probabilities
 mdp.C = C; % preferred states
 mdp.D = D; % priors over initial states
-mdp.d = d; % enable learning priors over initial states
-
+mdp.d = d; mdp.d_0 = d; % enable learning priors over initial states
+                        % d_0 is floor value for forgetting
 if Gen_model == 1
 mdp.E = E; % prior over policies
 elseif Gen_model == 2
-mdp.a = a; % enable learning state-outcome mappings
-mdp.e = e; % enable learning of prior over policies
+mdp.a = a; mdp.a_0 = a; % enable learning state-outcome mappings and set floor value for forgetting (a_0)
+mdp.e = e; mdp.e_0 = e; % enable learning of prior over policies and set floor value for forgetting (e_0)
 end
 
 mdp.eta = eta; % learning rate
@@ -1225,4 +1225,4 @@
 
 MDP = mdp;
 
-end
+end

Step_by_Step_AI_Guide.m

+30 -13
@@ -4,6 +4,7 @@
 % Application to Empirical Data
 
 % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte
+% UPDATED: 8/28/2024 (modified forgetting rate implementation)
 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 
@@ -48,7 +49,7 @@
 % To reproduce fig. 11, use values of 3 or 4 (with Sim = 3)
 % This will have no effect on Sim = 4 or Sim = 5
 
-Sim = 1;
+Sim = 2;
 
 % When Sim = 5, if PEB = 1 the script will run simulated group-level
 % (Parametric Empirical Bayes) analyses.
@@ -358,7 +359,7 @@
 
 % Note that, expanded out, this means that the other C-matrices will be:
 
-% C{1} = [0 0 0;  % No Hint
+% C{1} = [0 0 0;  % No Hint
 %         0 0 0;  % Machine-Left Hint
 %         0 0 0]; % Machine-Right Hint
 %
@@ -376,7 +377,18 @@
 % This will not be simulated here. However, this works by increasing the
 % preference magnitude for an outcome each time that outcome is observed.
 % The assumption here is that preferences naturally increase for entering
-% situations that are more familiar.
+% situations that are more familiar. To do so, you can specify starting
+% concentration parameters. For example:
+
+% c{1} = zeros(No(1),T); % Hints
+% c{2} = zeros(No(2),T); % Wins/Losses
+% c{3} = zeros(No(3),T); % Observed Behaviors
+%
+% c{2}(:,:) = [1 1 1  ;  % Null
+%              1 0 0.5;  % Loss
+%              1 2 1.5]; % win
+
+% NOTE: These values must be non-negative; higher values = more preferred
 
 % Allowable policies: U or V.
 %==========================================================================
@@ -461,12 +473,17 @@
 % degree to which newer experience can 'over-write' what has been learned
 % from older experiences. It is adaptive in environments where the true
 % parameters in the generative process (priors, likelihoods, etc.) can
-% change over time. A low value for omega can be seen as a prior that the
+% change over time. A high value for omega can be seen as a prior that the
 % world is volatile and that contingencies change over time.
 
-omega = 1; % By default we here set this to 1 (indicating no forgetting,
+omega = 0.0; % By default we here set this to 0 (indicating no forgetting,
 % but try changing its value to see how it affects model behavior.
-% Values below 1 indicate greater rates of forgetting.
+% Values approaching 1 indicate greater rates of forgetting.
+% NOTE: Trial 1 concentration parameter values are set as
+% floor values (forgetting cannot reduce counts below those
+% values - THIS IS MODIFIED FROM THE PUBLISHED TUTORIAL VERSION
+% SO THAT CONCENTRATION PARAMETERS ABOVE THE FLOOR VALUE
+% ARE MULTIPLIED BY 1-OMEGA)
 
 % Beta: Expected precision of expected free energy (G) over policies (a
 % positive value, with higher values indicating lower expected precision).
@@ -577,8 +594,8 @@
 mdp.C = C; % preferred states
 mdp.D = D; % priors over initial states
 
-mdp.d = d; % enable learning priors over initial states
-
+mdp.d = d; mdp.d_0 = mdp.d; % enable learning priors over initial states
+                            % and set lower bound on concentration paramaters (d_0)
 mdp.eta = eta; % learning rate
 mdp.omega = omega; % forgetting rate
 mdp.alpha = alpha; % action precision
@@ -591,10 +608,10 @@
 % mdp.E = E;
 
 % or learning other parameters:
-% mdp.a = a;
-% mdp.b = b;
-% mdp.c = c;
-% mdp.e = e;
+% mdp.a = a; mdp.a_0 = mdp.a;
+% mdp.b = b; mdp.b_0 = mdp.b;
+% mdp.c = c; mdp.c_0 = mdp.c; clear mdp.C = C;
+% mdp.e = e; mdp.e_0 = mdp.e;
 
 % or specifying true states or outcomes:
 
@@ -1425,4 +1442,4 @@
 % now build a generative model of a task, run simulations, assess parameter
 % recoverability, do bayesian model comparison, and do hierarchical
 % bayesian group analyses. See the main text for further explanation of
-% other aspects of these steps.
+% other aspects of these steps.

spm_MDP_VB_X_tutorial.m

+13 -13
@@ -1,4 +1,5 @@
 function [MDP] = spm_MDP_VB_X_tutorial(MDP,OPTIONS)
+% UPDATED: 8/28/2024 (modified forgetting rate implementation)
 
 % active inference and learning using variational message passing
 % FORMAT [MDP] = spm_MDP_VB_X_tutorial(MDP,OPTIONS)
@@ -149,7 +150,7 @@
 
 % check MDP specification
 %--------------------------------------------------------------------------
-MDP = spm_MDP_check(MDP);
+MDP = spm_MDP_check(MDP);
 
 % handle multiple trials, ensuring parameters (and posteriors) are updated
 %==========================================================================
@@ -225,7 +226,7 @@
 try, omega = MDP(1).omega; catch, omega = 1; end % forgetting rate
 try, tau = MDP(1).tau; catch, tau = 4; end % update time constant
 try, chi = MDP(1).chi; catch, chi = 1/64; end % Occam window updates
-try, erp = MDP(1).erp; catch, erp = 4; end % update reset
+try, erp = MDP(1).erp; catch, erp = 4; end % update reset
 
 % preclude precision updates for moving policies
 %--------------------------------------------------------------------------
@@ -258,8 +259,8 @@
 end
 for g = 1:Ng(m)
 No(m,g) = size(MDP(m).A{g},1); % number of outcomes
-end
-
+end
+
 % parameters of generative model and policies
 %======================================================================
 
@@ -306,7 +307,6 @@
 sB{m,f}(:,:,j) = spm_norm(MDP(m).B{f}(:,:,j) );
 rB{m,f}(:,:,j) = spm_norm(MDP(m).B{f}(:,:,j)');
 end
-
 end
 
 % prior concentration paramters for complexity
@@ -341,7 +341,7 @@
 % priors over policies - concentration parameters
 %----------------------------------------------------------------------
 if isfield(MDP,'e')
-E{m} = spm_norm(MDP(m).e);
+E{m} = spm_norm(MDP(m).e);
 elseif isfield(MDP,'E')
 E{m} = spm_norm(MDP(m).E);
 else
@@ -377,6 +377,7 @@
 end
 end
 C{m,g} = spm_log(spm_softmax(C{m,g}));
+
 end
 
 % initialise posterior expectations of hidden states
@@ -1123,7 +1124,7 @@
 da = spm_cross(da,X{m,f}(:,t));
 end
 da = da.*(MDP(m).a{g} > 0);
-MDP(m).a{g} = MDP(m).a{g}*omega + da*eta;
+MDP(m).a{g} = (MDP(m).a{g}-MDP(m).a_0{g})*(1-omega) + MDP(m).a_0{g} + da*eta;
 end
 end
 
@@ -1135,7 +1136,7 @@
 v = V{m}(t - 1,k,f);
 db = u{m}(k,t)*x{m,f}(:,t,k)*x{m,f}(:,t - 1,k)';
 db = db.*(MDP(m).b{f}(:,:,v) > 0);
-MDP(m).b{f}(:,:,v) = MDP(m).b{f}(:,:,v)*omega + db*eta;
+MDP(m).b{f}(:,:,v) = (MDP(m).b{f}(:,:,v)-MDP(m).b_0{f}(:,:,v))*(1-omega) + MDP(m).b_0{f}(:,:,v) + db*eta;
 end
 end
 end
@@ -1147,10 +1148,10 @@
 dc = O{m}{g,t};
 if size(MDP(m).c{g},2) > 1
 dc = dc.*(MDP(m).c{g}(:,t) > 0);
-MDP(m).c{g}(:,t) = MDP(m).c{g}(:,t)*omega + dc*eta;
+MDP(m).c{g}(:,t) = (MDP(m).c{g}(:,t)-MDP(m).c_0{g}(:,t))*(1-omega) + MDP(m).c_0{g}(:,t) + dc*eta;
 else
 dc = dc.*(MDP(m).c{g}>0);
-MDP(m).c{g} = MDP(m).c{g}*omega + dc*eta;
+MDP(m).c{g} = (MDP(m).c{g}-c_0{g})*(1-omega) + c_0{g} + dc*eta;
 end
 end
 end
@@ -1161,14 +1162,14 @@
 if isfield(MDP,'d')
 for f = 1:Nf(m)
 i = MDP(m).d{f} > 0;
-MDP(m).d{f}(i) = MDP(m).d{f}(i)*omega + X{m,f}(i,1)*eta;
+MDP(m).d{f}(i) = (MDP(m).d{f}(i)-MDP(m).d_0{f}(i))*(1-omega) + MDP(m).d_0{f}(i) + X{m,f}(i,1)*eta;
 end
 end
 
 % policies
 %----------------------------------------------------------------------
 if isfield(MDP,'e')
-MDP(m).e = MDP(m).e*omega + eta*u{m}(:,T);
+MDP(m).e = (MDP(m).e-MDP(m).e_0)*(1-omega) + MDP(m).e_0 + eta*u{m}(:,T);
 end
 
 % (negative) free energy of parameters (complexity): outcome specific
@@ -1682,4 +1683,3 @@
 
 end
 
-

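For convenience, here is a minimal usage sketch of how a model specification opts into the revised forgetting scheme, assembled from the fields shown in the diffs above. The variables A, B, C, D, d, T and the particular omega value are placeholders assumed to be built as in the tutorial scripts; nothing here is new API beyond what the diffs show.

% Usage sketch (assumes A, B, C, D, d, and T are specified as in the tutorial scripts)
mdp.T = T;                 % number of time points per trial
mdp.A = A;                 % likelihood mappings
mdp.B = B;                 % transition probabilities
mdp.C = C;                 % preferred outcomes
mdp.D = D;                 % fixed priors over initial states
mdp.d = d; mdp.d_0 = d;    % enable learning of d; d_0 is the forgetting floor
mdp.eta   = 1;             % learning rate
mdp.omega = 0.2;           % forgetting rate (0 = no forgetting; values near 1 = strong forgetting)

% MDP = spm_MDP_VB_X_tutorial(mdp); % then run the scheme, as in the tutorial scripts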