Skip to content

Commit e868601

Browse files
authored
Pt3 force batch (evangelistalab#296)
* Fix batching in the old DSRG-MRPT3 code * Rename the IGNORE_MEMORY_WARNINGS keyword to IGNORE_MEMORY_ERRORS * Change test case dsrg-mrpt3-9 to test the batching algorithm
1 parent b3baf98 commit e868601

File tree

5 files changed

+1443
-698
lines changed

5 files changed

+1443
-698
lines changed

forte/mrdsrg-spin-integrated/dsrg_mrpt3.cc

+23-25
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ DSRG_MRPT3::DSRG_MRPT3(std::shared_ptr<RDMs> rdms, std::shared_ptr<SCFInfo> scf_
7171
DSRG_MRPT3::~DSRG_MRPT3() { cleanup(); }
7272

7373
void DSRG_MRPT3::startup() {
74+
enforce_batching_ = foptions_->get_bool("DSRG_MRPT3_BATCHED");
75+
ignore_memory_errors_ = foptions_->get_bool("IGNORE_MEMORY_ERRORS");
76+
7477
// lambda to print memory in good-looking unit
7578
auto to_XB = [](size_t nele, size_t type_size) {
7679
auto p = to_xb(nele, type_size);
@@ -92,10 +95,6 @@ void DSRG_MRPT3::startup() {
9295
// memory usage
9396
mem_total_ = static_cast<int64_t>(0.98 * psi::Process::environment.get_memory());
9497

95-
if (foptions_->get_bool("DSRG_MRPT3_BATCHED")) {
96-
mem_total_ = 0;
97-
}
98-
9998
std::vector<std::pair<std::string, std::string>> mem_info{
10099
{"Memory asigned", to_XB(mem_total_, 1)}};
101100

@@ -110,11 +109,6 @@ void DSRG_MRPT3::startup() {
110109
B_ = BTF_->build(tensor_type_, "B 3-idx", {"Lgg", "LGG"});
111110
fill_three_index_ints(B_);
112111

113-
/// B_.iterate([&](const std::vector<size_t>& i, const std::vector<SpinType>&,
114-
/// double& value) {
115-
//// value = ints_->three_integral(i[0], i[1], i[2]);
116-
// });
117-
118112
size_t sL = aux_mos_.size();
119113
nelement += sL * sg * sg;
120114
mem_info.push_back({"Memory used before DSRG", to_XB(nelement, sizeof(double))});
@@ -232,22 +226,23 @@ void DSRG_MRPT3::startup() {
232226
outfile->Printf("\n %-40s %15s", str_dim.first.c_str(), str_dim.second.c_str());
233227
}
234228

235-
if (mem_total_ < static_cast<int64_t>(nele_larger * sizeof(double)) and
236-
(not foptions_->get_bool("IGNORE_MEMORY_WARNINGS"))) {
229+
if (mem_total_ < static_cast<int64_t>(nele_larger * sizeof(double))) {
237230
outfile->Printf("\n\n Error: Not enough memory to compute DSRG-MRPT3 energy.");
238231
outfile->Printf("\n Minimum memory required: %s\n",
239232
to_XB(nele_larger, sizeof(double)).c_str());
240-
throw psi::PSIEXCEPTION("Not enough memory to compute DSRG-MRPT3 energy.");
233+
if (!ignore_memory_errors_)
234+
throw psi::PSIEXCEPTION("Not enough memory to compute DSRG-MRPT3 energy.");
241235
}
242236

243237
// Check memory for dipole moment
244238
size_t shp = sh * sp;
245239
size_t saa = sa * sa;
246240
int64_t mem_dipole = sizeof(double) * (6 * (sg * sg) + 9 * (shp * shp - saa * saa));
247-
if (mem_total_ < mem_dipole && do_dm_) {
241+
if (mem_total_ < mem_dipole and do_dm_) {
248242
outfile->Printf("\n\n Error: Not enough memory to compute DSRG-MRPT3 dipole.");
249243
outfile->Printf("\n Minimum memory required: %s\n", to_XB(mem_dipole, 1).c_str());
250-
throw psi::PSIEXCEPTION("Not enough memory to compute DSRG-MRPT3 dipole.");
244+
if (!ignore_memory_errors_)
245+
throw psi::PSIEXCEPTION("Not enough memory to compute DSRG-MRPT3 dipole.");
251246
}
252247
}
253248

@@ -592,7 +587,7 @@ double DSRG_MRPT3::compute_energy_pt3_1() {
592587
int64_t mem_max = sizeof(double) * (6 * (shp - saa) + 9 * (shp * shp - saa * saa));
593588
int64_t mem_min = sizeof(double) * (6 * (shp - saa) + 3 * (shp * shp - saa * saa));
594589

595-
if (mem_total_ < mem_min and (not foptions_->get_bool("IGNORE_MEMORY_WARNINGS"))) {
590+
if (mem_total_ < mem_min and !ignore_memory_errors_) {
596591
throw psi::PSIEXCEPTION("Not enough memory for compute_energy_pt3_1 in DSRG-MRPT3.");
597592
} else if (mem_total_ >= mem_max) {
598593

@@ -2492,8 +2487,8 @@ void DSRG_MRPT3::V_T2_C2_DF(BlockedTensor& B, BlockedTensor& T2, const double& a
24922487
sizeof(double) *
24932488
(2 * (p * h - a * a) + 3 * (p * p * h * h - a * a * a * a)); // local memory used in pt3_2
24942489
if (mem_total_ < 0 or static_cast<size_t>(mem_total_) < v * v * sizeof(double)) {
2495-
if (not foptions_->get_bool("IGNORE_MEMORY_WARNINGS")) {
2496-
outfile->Printf("\n Not enough memory for batching.");
2490+
outfile->Printf("\n Not enough memory for batching.");
2491+
if (!ignore_memory_errors_) {
24972492
throw psi::PSIEXCEPTION("Not enough memory for batching at DSRG-MRPT3 V_T2_C2_DF.");
24982493
}
24992494
}
@@ -2645,7 +2640,7 @@ void DSRG_MRPT3::V_T2_C2_DF(BlockedTensor& B, BlockedTensor& T2, const double& a
26452640
}
26462641

26472642
// particle-particle contractions
2648-
if (static_cast<int64_t>(nele_pp_max * sizeof(double)) < mem_total_) {
2643+
if (static_cast<int64_t>(nele_pp_max * sizeof(double)) < mem_total_ and !enforce_batching_) {
26492644

26502645
// set timer
26512646
start_ = std::chrono::system_clock::now();
@@ -2938,7 +2933,7 @@ void DSRG_MRPT3::V_T2_C2_DF(BlockedTensor& B, BlockedTensor& T2, const double& a
29382933
}
29392934

29402935
// compute exchange part
2941-
if (static_cast<int64_t>(nele_ph_max * sizeof(double)) < mem_total_) {
2936+
if (static_cast<int64_t>(nele_ph_max * sizeof(double)) < mem_total_ and !enforce_batching_) {
29422937
start_ = std::chrono::system_clock::now();
29432938
tt1_ = std::chrono::system_clock::to_time_t(start_);
29442939
if (profile_print_) {
@@ -3551,8 +3546,9 @@ void DSRG_MRPT3::V_T2_C2_DF_VV(BlockedTensor& B, BlockedTensor& T2, const double
35513546
outfile->Printf("\n Not enough memory for batching tensor "
35523547
"H2(%zu * %zu * %zu * %zu).",
35533548
sh0, sh1, sv, sv);
3554-
throw psi::PSIEXCEPTION("Not enough memory for batching at "
3555-
"DSRG-MRPT3 V_T2_C2_DF_VV.");
3549+
if (!ignore_memory_errors_)
3550+
throw psi::PSIEXCEPTION("Not enough memory for batching at "
3551+
"DSRG-MRPT3 V_T2_C2_DF_VV.");
35563552
}
35573553

35583554
// 1st virtual index
@@ -3933,8 +3929,9 @@ void DSRG_MRPT3::V_T2_C2_DF_VC_EX(BlockedTensor& B, BlockedTensor& T2, const dou
39333929
outfile->Printf("\n Not enough memory for batching tensor "
39343930
"H2(%zu * %zu * %zu * %zu).",
39353931
sq, ss, sc, sv);
3936-
throw psi::PSIEXCEPTION("Not enough memory for batching at DSRG-MRPT3 "
3937-
"V_T2_C2_DF_VC_EX.");
3932+
if (!ignore_memory_errors_)
3933+
throw psi::PSIEXCEPTION("Not enough memory for batching at DSRG-MRPT3 "
3934+
"V_T2_C2_DF_VC_EX.");
39383935
}
39393936

39403937
// fill the indices of sub virtuals
@@ -4261,8 +4258,9 @@ void DSRG_MRPT3::V_T2_C2_DF_VA_EX(BlockedTensor& B, BlockedTensor& T2, const dou
42614258
outfile->Printf("\n Not enough memory for batching tensor "
42624259
"H2(%zu * %zu * %zu * %zu).",
42634260
sq, ss, sa, sv);
4264-
throw psi::PSIEXCEPTION("Not enough memory for batching at DSRG-MRPT3 "
4265-
"V_T2_C2_DF_VA_EX.");
4261+
if (!ignore_memory_errors_)
4262+
throw psi::PSIEXCEPTION("Not enough memory for batching at DSRG-MRPT3 "
4263+
"V_T2_C2_DF_VA_EX.");
42664264
}
42674265

42684266
// fill the indices of sub virtuals

forte/mrdsrg-spin-integrated/dsrg_mrpt3.h

+4
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ class DSRG_MRPT3 : public MASTER_DSRG {
9999

100100
/// Total memory left
101101
int64_t mem_total_;
102+
/// Enforce batching algorithm
103+
bool enforce_batching_;
104+
/// Ignore memory warnings and errors
105+
bool ignore_memory_errors_;
102106

103107
/// Fill up two-electron integrals
104108
void build_tei(BlockedTensor& V);

forte/register_forte_options.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def register_dsrg_options(options):
725725

726726
options.add_bool("DSRG_MRPT3_BATCHED", False, "Force running the DSRG-MRPT3 code using the batched algorithm")
727727

728-
options.add_bool("IGNORE_MEMORY_WARNINGS", False, "Force running the DSRG-MRPT3 code using the batched algorithm")
728+
options.add_bool("IGNORE_MEMORY_ERRORS", False, "Continue running DSRG-MRPT3 even if memory exceeds")
729729

730730
options.add_int(
731731
"DSRG_DIIS_START", 2, "Iteration cycle to start adding error vectors for"

tests/methods/dsrg-mrpt3-9/input.dat

+38-36
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#! Generated using commit GITCOMMIT
1+
# Test DSRG-MRPT3 batching algorithm
22

33
import forte
44

5-
refmcscf = -99.406065223639
6-
refdsrgpt3 = -99.498903267935276
5+
refmcscf = -99.981029411803
6+
refdsrgpt3 = -100.246555424489
77

88
molecule HF{
99
0 1
@@ -13,44 +13,46 @@ molecule HF{
1313
}
1414

1515
set globals{
16-
basis 3-21g
17-
reference twocon
18-
scf_type pk
19-
e_convergence 8
20-
maxiter 100
21-
docc [3,0,1,1]
16+
reference rhf
17+
scf_type df
18+
basis cc-pvqz
19+
df_basis_scf cc-pvqz-jkfit
20+
df_basis_mp2 cc-pvqz-jkfit
21+
e_convergence 8
22+
maxiter 100
23+
docc [3,0,1,1]
24+
mcscf_type df
25+
restricted_docc [2,0,1,1]
26+
active [2,0,0,0]
27+
mcscf_diis_start 20
28+
mcscf_maxiter 60
29+
mcscf_e_convergence 10
2230
}
2331

24-
set mcscf{
25-
docc [2,0,1,1]
26-
socc [2,0,0,0]
27-
maxiter 1000
28-
level_shift 0.5
29-
d_convergence 10
30-
e_convergence 12
31-
}
32+
Emcscf, wfn = energy('casscf', return_wfn=True)
33+
compare_values(refmcscf,variable("CURRENT ENERGY"),9,"MCSCF energy")
3234

3335
set forte{
34-
active_space_solver fci
35-
correlation_solver dsrg-mrpt3
36-
dsrg_mrpt3_batched true
37-
ignore_memory_warnings true
38-
int_type cholesky
39-
cholesky_tolerance 1e-12
40-
frozen_docc [1,0,0,0]
41-
restricted_docc [1,0,1,1]
42-
active [2,0,0,0]
43-
root_sym 0
44-
nroot 1
45-
dsrg_s 1.0
46-
relax_ref once
47-
maxiter 100
48-
e_convergence 8
49-
semi_canonical false
36+
active_space_solver fci
37+
correlation_solver dsrg-mrpt3
38+
int_type df
39+
frozen_docc [1,0,0,0]
40+
restricted_docc [1,0,1,1]
41+
active [2,0,0,0]
42+
root_sym 0
43+
nroot 1
44+
dsrg_s 1.0
45+
relax_ref once
46+
maxiter 100
47+
e_convergence 8
48+
semi_canonical false
5049
}
5150

52-
Emcscf, wfn = energy('mcscf', return_wfn=True)
53-
compare_values(refmcscf,variable("CURRENT ENERGY"),10,"MCSCF energy") #TEST
51+
energy('forte', ref_wfn=wfn)
52+
compare_values(refdsrgpt3,variable("CURRENT ENERGY"),8,"DSRG-MRPT3 relaxed energy")
53+
54+
memory 1 gb
55+
set forte dsrg_mrpt3_batched true
5456

5557
energy('forte', ref_wfn=wfn)
56-
compare_values(refdsrgpt3,variable("CURRENT ENERGY"),8,"DSRG-MRPT3 relaxed energy") #TEST
58+
compare_values(refdsrgpt3,variable("CURRENT ENERGY"),8,"DSRG-MRPT3 relaxed energy")

0 commit comments

Comments
 (0)