Skip to content

Commit

Permalink
Removes unused fields from priors struct (#348)
Browse files Browse the repository at this point in the history
* Removes unused fields from priors struct

* Moves priors into problemStruct

* Addresses review comments
  • Loading branch information
DrPaulSharp authored Feb 18, 2025
1 parent e3c5909 commit de747af
Show file tree
Hide file tree
Showing 41 changed files with 92 additions and 174 deletions.
2 changes: 1 addition & 1 deletion 3rdParty/paramonte/processParamonteRuns.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
%
% controls = controlsClass();

[problemStruct,problemLimits,~,controls] = parseClassToStructs(problem,controls);
[problemStruct,problemLimits,controls] = parseClassToStructs(problem,controls);

[problemStruct,fitNames] = packParams(problemStruct,problemLimits);

Expand Down
2 changes: 1 addition & 1 deletion 3rdParty/paramonte/runParamonte.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
rng('default');

% Split problem using the routines from RAT..
[problemStruct,problemLimits,priors,controls] = parseClassToStructs(project,inputControls);
[problemStruct,problemLimits,controls] = parseClassToStructs(project,inputControls);

%controls.parallel = coderEnums.parallelOptions.Points;

Expand Down
4 changes: 2 additions & 2 deletions API/RAT.m
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
% chain [M x nParams] double MCMC chains where M is the length of each chain
% =================== ==================== ===============

[problemStruct,problemLimits,priors,controls] = parseClassToStructs(project,controls);
[problemStruct,problemLimits,controls] = parseClassToStructs(project,controls);

% Set controls.calcSLD to 1 if we are doing customXY
switch lower(problemStruct.modelType)
Expand All @@ -78,7 +78,7 @@
end

tic
[problemStruct,result,bayesResults] = RATMain_mex(problemStruct,problemLimits,controls,priors);
[problemStruct,result,bayesResults] = RATMain_mex(problemStruct,problemLimits,controls);

if display
toc
Expand Down
6 changes: 3 additions & 3 deletions API/RATMain.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [problemStruct,result,bayesResults] = RATMain(problemStruct,problemLimits,controls,priors)
function [problemStruct,result,bayesResults] = RATMain(problemStruct,problemLimits,controls)
coderEnums.initialise()

if strcmpi(problemStruct.TF, coderEnums.calculationTypes.Domains)
Expand Down Expand Up @@ -29,12 +29,12 @@
if ~strcmpi(controls.display, coderEnums.displayOptions.Off)
triggerEvent(coderEnums.eventTypes.Message, sprintf('\nRunning Nested Sampler\n\n'));
end
[problemStruct,result,bayesResults] = runNestedSampler(problemStruct,problemLimits,controls,priors);
[problemStruct,result,bayesResults] = runNestedSampler(problemStruct,problemLimits,controls);
case coderEnums.procedures.Dream
if ~strcmpi(controls.display, coderEnums.displayOptions.Off)
triggerEvent(coderEnums.eventTypes.Message, sprintf('\nRunning DREAM\n\n'));
end
[problemStruct,result,bayesResults] = runDREAM(problemStruct,problemLimits,controls,priors);
[problemStruct,result,bayesResults] = runDREAM(problemStruct,problemLimits,controls);
otherwise
coderException(coderEnums.errorCodes.invalidOption, 'The procedure "%s" is not supported. The procedure must be one of "%s"', controls.procedure, strjoin(fieldnames(coderEnums.procedures), '", "'));
end
Expand Down
100 changes: 48 additions & 52 deletions API/parseClassToStructs.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [problemStruct,problemLimits,priors,controls] = parseClassToStructs(project,inputControls)
function [problemStruct,problemLimits,controls] = parseClassToStructs(project,inputControls)

% Breaks up the classes into the relevant structures for inputting into C

Expand All @@ -13,56 +13,6 @@
inputData{i} = [contrastData zeros(size(contrastData,1), 6-size(contrastData,2))];
end

%% Put the priors into their own array
priors.params = inputStruct.paramPriors;
priors.backgroundParams = inputStruct.backgroundParamPriors;
priors.scalefactors = inputStruct.scalefactorPriors;
priors.qzshifts = inputStruct.qzshiftPriors;
priors.bulkIns = inputStruct.bulkInPriors;
priors.bulkOuts = inputStruct.bulkOutPriors;
priors.resolutionParams = inputStruct.resolutionParamPriors;
if isa(project, 'domainsClass')
priors.domainRatios = inputStruct.domainRatioPriors;
else
priors.domainRatios = cell(0,1);
end

priorFields = fieldnames(priors);
totalNumber = 0;
for i=1:length(priorFields)
totalNumber = totalNumber + size(priors.(priorFields{i}), 1);
end

priorsCell = cell(totalNumber,4);
cellCount = 1;

for i=1:length(priorFields)
currentPrior = priorFields{i};
for j = 1:size(priors.(currentPrior), 1)
priorsCell{cellCount,1} = priors.(currentPrior){j}{1};

% Check prior type.....
thisType = priors.(currentPrior){j}{2};

if strcmpi(thisType, priorTypes.Uniform.value)
priorType = 1;
elseif strcmpi(thisType, priorTypes.Gaussian.value)
priorType = 2;
else
priorType = 3;
end
priorsCell{cellCount,2} = priorType;

priorsCell{cellCount,3} = priors.(currentPrior){j}{3};
priorsCell{cellCount,4} = priors.(currentPrior){j}{4};
cellCount = cellCount + 1;
end
end

priors.priorNames = priorsCell(:, 1);
priors.priorValues = cell2mat(priorsCell(:, 2:end));


%% Deal with backgrounds and resolutions

% Convert contrastBackgrounds to custom file/parameter indices
Expand Down Expand Up @@ -277,6 +227,52 @@
problemStruct.fitLimits = [];
problemStruct.otherLimits = [];

%% Put the priors into their fields

priorFields = {"paramPriors", "backgroundParamPriors", ...
"scalefactorPriors","qzshiftPriors", "bulkInPriors", ...
"bulkOutPriors", "resolutionParamPriors"};

if isa(project, 'domainsClass')
priorFields{end+1} = "domainRatioPriors";
end

totalNumber = 0;
for i=1:length(priorFields)
totalNumber = totalNumber + size(inputStruct.(priorFields{i}), 1);
end

priorsCell = cell(totalNumber,4);
cellCount = 1;

for i=1:length(priorFields)
currentPrior = priorFields{i};
for j = 1:size(inputStruct.(currentPrior), 1)
priorsCell{cellCount,1} = inputStruct.(currentPrior){j}{1};

% Check prior type
thisType = inputStruct.(currentPrior){j}{2};

if strcmpi(thisType, priorTypes.Uniform.value)
priorType = 1;
elseif strcmpi(thisType, priorTypes.Gaussian.value)
priorType = 2;
else
priorType = 3;
end
priorsCell{cellCount,2} = priorType;

priorsCell{cellCount,3} = inputStruct.(currentPrior){j}{3};
priorsCell{cellCount,4} = inputStruct.(currentPrior){j}{4};
cellCount = cellCount + 1;
end
end

problemStruct.priorNames = priorsCell(:, 1);
problemStruct.priorValues = cell2mat(priorsCell(:, 2:end));

%% Add structs for parameter names and fits

% Record lists of parameter names
problemStruct.names.params = inputStruct.paramNames;
problemStruct.names.backgroundParams = inputStruct.backgroundParamNames;
Expand All @@ -292,7 +288,7 @@
end
problemStruct.names.contrasts = inputStruct.contrastNames;

% Also need to deal with the checks...
% Record lists of parameter fits
problemStruct.checks.params = inputStruct.fitParam;
problemStruct.checks.backgroundParams = inputStruct.fitBackgroundParam;
problemStruct.checks.scalefactors = inputStruct.fitScalefactor;
Expand Down
74 changes: 4 additions & 70 deletions compile/fullCompile/makeCompileArgsFull.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
maxDataSize = 10000;

ARGS = cell(1,1);
ARGS{1} = cell(4,1);
ARGS{1} = cell(3,1);
ARGS_1_1 = struct;
ARGS_1_1.TF = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS_1_1.resample = coder.typeof(0,[1 maxArraySize],[0 1]);
Expand Down Expand Up @@ -63,6 +63,9 @@
ARGS_1_1.otherParams = coder.typeof(0,[1 maxArraySize],[0 1]);
ARGS_1_1.fitLimits = coder.typeof(0,[maxArraySize 2],[1 0]);
ARGS_1_1.otherLimits = coder.typeof(0,[maxArraySize 2],[1 0]);
ARG = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS_1_1.priorNames = coder.typeof({ARG}, [maxArraySize 1],[1 0]);
ARGS_1_1.priorValues = coder.typeof(0, [maxArraySize 3], [1 0]);
ARGS_1_1_names = struct;
ARG = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS_1_1_names.params = coder.typeof({ARG}, [1 maxArraySize],[0 1]);
Expand Down Expand Up @@ -127,74 +130,5 @@
ARGS_1_3.adaptPCR = coder.typeof(true);
ARGS_1_3.IPCFilePath = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS{1}{3} = coder.typeof(ARGS_1_3);
ARGS_1_4 = struct;
ARG_20 = cell([1 4]);
ARG_20{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_20{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_20{3} = coder.typeof(0);
ARG_20{4} = coder.typeof(0);
ARG_20 = coder.typeof(ARG_20,[1 4]);
ARG_20 = ARG_20.makeHeterogeneous();
ARGS_1_4.params = coder.typeof({ARG_20}, [maxArraySize 1],[1 0]);
ARG_21 = cell([1 4]);
ARG_21{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_21{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_21{3} = coder.typeof(0);
ARG_21{4} = coder.typeof(0);
ARG_21 = coder.typeof(ARG_21,[1 4]);
ARG_21 = ARG_21.makeHeterogeneous();
ARGS_1_4.backgroundParams = coder.typeof({ARG_21}, [maxArraySize 1],[1 0]);
ARG_22 = cell([1 4]);
ARG_22{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_22{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_22{3} = coder.typeof(0);
ARG_22{4} = coder.typeof(0);
ARG_22 = coder.typeof(ARG_22,[1 4]);
ARG_22 = ARG_22.makeHeterogeneous();
ARGS_1_4.scalefactors = coder.typeof({ARG_22}, [maxArraySize 1],[1 0]);
ARG_23 = cell([1 4]);
ARG_23{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_23{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_23{3} = coder.typeof(0);
ARG_23{4} = coder.typeof(0);
ARG_23 = coder.typeof(ARG_23,[1 4]);
ARG_23 = ARG_23.makeHeterogeneous();
ARGS_1_4.qzshifts = coder.typeof({ARG_23}, [maxArraySize 1],[1 0]);
ARG_24 = cell([1 4]);
ARG_24{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_24{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_24{3} = coder.typeof(0);
ARG_24{4} = coder.typeof(0);
ARG_24 = coder.typeof(ARG_24,[1 4]);
ARG_24 = ARG_24.makeHeterogeneous();
ARGS_1_4.bulkIns = coder.typeof({ARG_24}, [maxArraySize 1],[1 0]);
ARG_25 = cell([1 4]);
ARG_25{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_25{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_25{3} = coder.typeof(0);
ARG_25{4} = coder.typeof(0);
ARG_25 = coder.typeof(ARG_25,[1 4]);
ARG_25 = ARG_25.makeHeterogeneous();
ARGS_1_4.bulkOuts = coder.typeof({ARG_25}, [maxArraySize 1],[1 0]);
ARG_26 = cell([1 4]);
ARG_26{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_26{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_26{3} = coder.typeof(0);
ARG_26{4} = coder.typeof(0);
ARG_26 = coder.typeof(ARG_26,[1 4]);
ARG_26 = ARG_26.makeHeterogeneous();
ARGS_1_4.resolutionParams = coder.typeof({ARG_26}, [maxArraySize 1],[1 0]);
ARG_27 = cell([1 4]);
ARG_27{1} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_27{2} = coder.typeof('X',[1 maxArraySize],[0 1]);
ARG_27{3} = coder.typeof(0);
ARG_27{4} = coder.typeof(0);
ARG_27 = coder.typeof(ARG_27,[1 4]);
ARG_27 = ARG_27.makeHeterogeneous();
ARGS_1_4.domainRatios = coder.typeof({ARG_27}, [maxArraySize 1],[1 0]);
ARG_28 = coder.typeof('X',[1 Inf],[0 1]);
ARGS_1_4.priorNames = coder.typeof({ARG_28}, [maxArraySize 1],[1 0]);
ARGS_1_4.priorValues = coder.typeof(0, [maxArraySize 3], [1 0]);
ARGS{1}{4} = coder.typeof(ARGS_1_4);

end
3 changes: 3 additions & 0 deletions compile/reflectivityCalculation/makeCompileArgs.m
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
ARGS_1_1.otherParams = coder.typeof(0,[1 maxArraySize],[0 1]);
ARGS_1_1.fitLimits = coder.typeof(0,[maxArraySize 2],[1 0]);
ARGS_1_1.otherLimits = coder.typeof(0,[maxArraySize 2],[1 0]);
ARG = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS_1_1.priorNames = coder.typeof({ARG}, [maxArraySize 1],[1 0]);
ARGS_1_1.priorValues = coder.typeof(0, [maxArraySize 3], [1 0]);
ARGS_1_1_names = struct;
ARG = coder.typeof('X',[1 maxArraySize],[0 1]);
ARGS_1_1_names.params = coder.typeof({ARG}, [1 maxArraySize],[0 1]);
Expand Down
Binary file not shown.
5 changes: 3 additions & 2 deletions minimisers/DREAM/runDREAM.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [outProblemStruct,result,bayesResults] = runDREAM(problemStruct,problemLimits,controls,priors)
function [outProblemStruct,result,bayesResults] = runDREAM(problemStruct,problemLimits,controls)


% Make an empty struct for bayesResults to hold the outputs of the
Expand Down Expand Up @@ -33,7 +33,8 @@
[problemStruct,fitParamNames] = packParams(problemStruct,problemLimits);

% Get the priors for the fitted parameters...
priorList = getFittedPriors(fitParamNames,priors,problemStruct.fitLimits);
priorList = getFittedPriors(fitParamNames, problemStruct.priorNames, ...
problemStruct.priorValues, problemStruct.fitLimits);

% Put all the RAT parameters together into one array...
ratInputs.problemStruct = problemStruct;
Expand Down
9 changes: 5 additions & 4 deletions minimisers/NS/runNestedSampler.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [problemStruct,result,bayesResults] = runNestedSampler(problemStruct,problemLimits,controls,inPriors)
function [problemStruct,result,bayesResults] = runNestedSampler(problemStruct,problemLimits,controls)

[problemStruct,fitNames] = packParams(problemStruct,problemLimits);

Expand All @@ -17,10 +17,11 @@
end
bayesResults = makeEmptyBayesResultsStruct(numberOfContrasts, domains, numberOfChains);

%Deal with priors.
priorList = getFittedPriors(fitNames,inPriors,problemStruct.fitLimits);
% Deal with priors.
priorList = getFittedPriors(fitNames, problemStruct.priorNames, ...
problemStruct.priorValues, problemStruct.fitLimits);

%Tuning Parameters
% Tuning Parameters
model.ssfun = @nsIntraFun;
nLive = controls.nLive;
tolerance = controls.nsTolerance;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
function priorFitList = getFittedPriors(paramNames,priors,fitLimits)

% Get the list of all the priors..
priorNames = priors.priorNames;
priorValues = priors.priorValues;
function priorFitList = getFittedPriors(paramNames,priorNames,priorValues,fitLimits)

% Find the values for fitParams
numberOfParams = length(paramNames);
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/domainsTFReflectivityCalculation/domainsCustomXYInputs.mat
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit de747af

Please sign in to comment.