Skip to content

Commit 2002cd7

Browse files
authored
Merge pull request #5373 from ylee88/sposet_templates
More template classes derived from `SPOSetT<T>`
2 parents 094b0c8 + 60101d2 commit 2002cd7

11 files changed

+170
-84
lines changed

src/Estimators/OneBodyDensityMatrices.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class OneBodyDensityMatrices : public OperatorEstBase
8484
/** @} */
8585

8686
//data members \todo analyze lifecycles allocation optimization or state?
87-
CompositeSPOSet basis_functions_;
87+
CompositeSPOSet<Value> basis_functions_;
8888
Vector<Value> basis_values_;
8989
Vector<Value> basis_norms_;
9090
Vector<Grad> basis_gradients_;

src/Estimators/tests/test_MagnetizationDensity.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,8 @@ TEST_CASE("MagnetizationDensity::IntegrationTest", "[estimators]")
368368
mup(1, iorb) = uprow1[iorb];
369369
mdn(1, iorb) = dnrow1[iorb];
370370
}
371-
auto spo_up = std::make_unique<ConstantSPOSet>("ConstantUpSet", nelec, norb);
372-
auto spo_dn = std::make_unique<ConstantSPOSet>("ConstantDnSet", nelec, norb);
371+
auto spo_up = std::make_unique<ConstantSPOSet<Value>>("ConstantUpSet", nelec, norb);
372+
auto spo_dn = std::make_unique<ConstantSPOSet<Value>>("ConstantDnSet", nelec, norb);
373373

374374
spo_up->setRefVals(mup);
375375
spo_dn->setRefVals(mdn);

src/QMCHamiltonians/DensityMatrices1B.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class DensityMatrices1B : public OperatorBase
6565

6666
//data members
6767
bool energy_mat;
68-
CompositeSPOSet basis_functions;
68+
CompositeSPOSet<Value_t> basis_functions;
6969
ValueVector basis_values;
7070
ValueVector basis_norms;
7171
GradVector basis_gradients;

src/QMCWaveFunctions/CompositeSPOSet.cpp

+56-37
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,27 @@ inline void insert_columns(const MAT1& small, MAT2& big, int offset_c)
3939
}
4040
} // namespace MatrixOperators
4141

42-
CompositeSPOSet::CompositeSPOSet(const std::string& my_name) : SPOSet(my_name)
42+
template<typename T>
43+
CompositeSPOSet<T>::CompositeSPOSet(const std::string& my_name) : SPOSetT<T>(my_name)
4344
{
44-
OrbitalSetSize = 0;
45+
SPOSet::OrbitalSetSize = 0;
4546
component_offsets.reserve(4);
4647
}
4748

48-
CompositeSPOSet::CompositeSPOSet(const CompositeSPOSet& other) : SPOSet(other)
49+
template<typename T>
50+
CompositeSPOSet<T>::CompositeSPOSet(const CompositeSPOSet& other) : SPOSet(other)
4951
{
5052
for (auto& element : other.components)
5153
{
5254
this->add(element->makeClone());
5355
}
5456
}
5557

56-
CompositeSPOSet::~CompositeSPOSet() = default;
58+
template<typename T>
59+
CompositeSPOSet<T>::~CompositeSPOSet() = default;
5760

58-
void CompositeSPOSet::add(std::unique_ptr<SPOSet> component)
61+
template<typename T>
62+
void CompositeSPOSet<T>::add(std::unique_ptr<SPOSet> component)
5963
{
6064
if (components.empty())
6165
component_offsets.push_back(0); //add 0
@@ -67,11 +71,12 @@ void CompositeSPOSet::add(std::unique_ptr<SPOSet> component)
6771
component_laplacians.emplace_back(norbs);
6872
component_spin_gradients.emplace_back(norbs);
6973

70-
OrbitalSetSize += norbs;
71-
component_offsets.push_back(OrbitalSetSize);
74+
SPOSet::OrbitalSetSize += norbs;
75+
component_offsets.push_back(SPOSet::OrbitalSetSize);
7276
}
7377

74-
void CompositeSPOSet::report()
78+
template<typename T>
79+
void CompositeSPOSet<T>::report()
7580
{
7681
app_log() << "CompositeSPOSet" << std::endl;
7782
app_log() << " ncomponents = " << components.size() << std::endl;
@@ -83,9 +88,11 @@ void CompositeSPOSet::report()
8388
}
8489
}
8590

86-
std::unique_ptr<SPOSet> CompositeSPOSet::makeClone() const { return std::make_unique<CompositeSPOSet>(*this); }
91+
template<typename T>
92+
std::unique_ptr<SPOSetT<T>> CompositeSPOSet<T>::makeClone() const { return std::make_unique<CompositeSPOSet>(*this); }
8793

88-
void CompositeSPOSet::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
94+
template<typename T>
95+
void CompositeSPOSet<T>::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
8996
{
9097
int n = 0;
9198
for (int c = 0; c < components.size(); ++c)
@@ -98,7 +105,8 @@ void CompositeSPOSet::evaluateValue(const ParticleSet& P, int iat, ValueVector&
98105
}
99106
}
100107

101-
void CompositeSPOSet::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
108+
template<typename T>
109+
void CompositeSPOSet<T>::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
102110
{
103111
int n = 0;
104112
for (int c = 0; c < components.size(); ++c)
@@ -115,12 +123,13 @@ void CompositeSPOSet::evaluateVGL(const ParticleSet& P, int iat, ValueVector& ps
115123
}
116124
}
117125

118-
void CompositeSPOSet::evaluateVGL_spin(const ParticleSet& P,
119-
int iat,
120-
ValueVector& psi,
121-
GradVector& dpsi,
122-
ValueVector& d2psi,
123-
ValueVector& dspin_psi)
126+
template<typename T>
127+
void CompositeSPOSet<T>::evaluateVGL_spin(const ParticleSet& P,
128+
int iat,
129+
ValueVector& psi,
130+
GradVector& dpsi,
131+
ValueVector& d2psi,
132+
ValueVector& dspin_psi)
124133
{
125134
int n = 0;
126135
for (int c = 0; c < components.size(); ++c)
@@ -139,12 +148,13 @@ void CompositeSPOSet::evaluateVGL_spin(const ParticleSet& P,
139148
}
140149
}
141150

142-
void CompositeSPOSet::evaluate_notranspose(const ParticleSet& P,
143-
int first,
144-
int last,
145-
ValueMatrix& logdet,
146-
GradMatrix& dlogdet,
147-
ValueMatrix& d2logdet)
151+
template<typename T>
152+
void CompositeSPOSet<T>::evaluate_notranspose(const ParticleSet& P,
153+
int first,
154+
int last,
155+
ValueMatrix& logdet,
156+
GradMatrix& dlogdet,
157+
ValueMatrix& d2logdet)
148158
{
149159
const int nat = last - first;
150160
for (int c = 0; c < components.size(); ++c)
@@ -161,12 +171,13 @@ void CompositeSPOSet::evaluate_notranspose(const ParticleSet& P,
161171
}
162172
}
163173

164-
void CompositeSPOSet::evaluate_notranspose(const ParticleSet& P,
165-
int first,
166-
int last,
167-
ValueMatrix& logdet,
168-
GradMatrix& dlogdet,
169-
HessMatrix& grad_grad_logdet)
174+
template<typename T>
175+
void CompositeSPOSet<T>::evaluate_notranspose(const ParticleSet& P,
176+
int first,
177+
int last,
178+
ValueMatrix& logdet,
179+
GradMatrix& dlogdet,
180+
HessMatrix& grad_grad_logdet)
170181
{
171182
const int nat = last - first;
172183
for (int c = 0; c < components.size(); ++c)
@@ -183,13 +194,14 @@ void CompositeSPOSet::evaluate_notranspose(const ParticleSet& P,
183194
}
184195
}
185196

186-
void CompositeSPOSet::evaluate_notranspose(const ParticleSet& P,
187-
int first,
188-
int last,
189-
ValueMatrix& logdet,
190-
GradMatrix& dlogdet,
191-
HessMatrix& grad_grad_logdet,
192-
GGGMatrix& grad_grad_grad_logdet)
197+
template<typename T>
198+
void CompositeSPOSet<T>::evaluate_notranspose(const ParticleSet& P,
199+
int first,
200+
int last,
201+
ValueMatrix& logdet,
202+
GradMatrix& dlogdet,
203+
HessMatrix& grad_grad_logdet,
204+
GGGMatrix& grad_grad_grad_logdet)
193205
{
194206
not_implemented("evaluate_notranspose(P,first,last,logdet,dlogdet,ddlogdet,dddlogdet)");
195207
}
@@ -204,7 +216,7 @@ std::unique_ptr<SPOSet> CompositeSPOSetBuilder::createSPOSetFromXML(xmlNodePtr c
204216
return nullptr;
205217
}
206218

207-
auto spo_now = std::make_unique<CompositeSPOSet>(getXMLAttributeValue(cur, "name"));
219+
auto spo_now = std::make_unique<CompositeSPOSet<ValueType>>(getXMLAttributeValue(cur, "name"));
208220
for (int i = 0; i < spolist.size(); ++i)
209221
{
210222
const SPOSet* spo = sposet_builder_factory_.getSPOSet(spolist[i]);
@@ -219,4 +231,11 @@ std::unique_ptr<SPOSet> CompositeSPOSetBuilder::createSPOSet(xmlNodePtr cur, SPO
219231
return createSPOSetFromXML(cur);
220232
}
221233

234+
#if !defined(MIXED_PRECISION)
235+
template class CompositeSPOSet<double>;
236+
template class CompositeSPOSet<std::complex<double>>;
237+
#endif
238+
template class CompositeSPOSet<float>;
239+
template class CompositeSPOSet<std::complex<float>>;
240+
222241
} // namespace qmcplusplus

src/QMCWaveFunctions/CompositeSPOSet.h

+12-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,20 @@
2222

2323
namespace qmcplusplus
2424
{
25-
class CompositeSPOSet : public SPOSet
25+
26+
template<typename T>
27+
class CompositeSPOSet : public SPOSetT<T>
2628
{
2729
public:
30+
using SPOSet = SPOSetT<T>;
31+
32+
using ValueVector = typename SPOSet::ValueVector;
33+
using ValueMatrix = typename SPOSet::ValueMatrix;
34+
using GradVector = typename SPOSet::GradVector;
35+
using GradMatrix = typename SPOSet::GradMatrix;
36+
using HessMatrix = typename SPOSet::HessMatrix;
37+
using GGGMatrix = typename SPOSet::GGGMatrix;
38+
2839
///component SPOSets
2940
std::vector<std::unique_ptr<SPOSet>> components;
3041
///temporary storage for values

src/QMCWaveFunctions/tests/ConstantSPOSet.cpp

+55-28
Original file line numberDiff line numberDiff line change
@@ -13,81 +13,99 @@
1313

1414
namespace qmcplusplus
1515
{
16-
ConstantSPOSet::ConstantSPOSet(const std::string& my_name, const int nparticles, const int norbitals)
17-
: SPOSet(my_name), numparticles_(nparticles)
16+
17+
template<typename T>
18+
ConstantSPOSet<T>::ConstantSPOSet(const std::string& my_name, const int nparticles, const int norbitals)
19+
: SPOSetT<T>(my_name), numparticles_(nparticles)
1820
{
19-
OrbitalSetSize = norbitals;
20-
ref_psi_.resize(numparticles_, OrbitalSetSize);
21-
ref_egrad_.resize(numparticles_, OrbitalSetSize);
22-
ref_elapl_.resize(numparticles_, OrbitalSetSize);
21+
SPOSet::OrbitalSetSize = norbitals;
22+
ref_psi_.resize(numparticles_, SPOSet::OrbitalSetSize);
23+
ref_egrad_.resize(numparticles_, SPOSet::OrbitalSetSize);
24+
ref_elapl_.resize(numparticles_, SPOSet::OrbitalSetSize);
2325

2426
ref_psi_ = 0.0;
2527
ref_egrad_ = 0.0;
2628
ref_elapl_ = 0.0;
2729
};
2830

29-
std::unique_ptr<SPOSet> ConstantSPOSet::makeClone() const
31+
template<typename T>
32+
std::unique_ptr<SPOSetT<T>> ConstantSPOSet<T>::makeClone() const
3033
{
31-
auto myclone = std::make_unique<ConstantSPOSet>(my_name_, numparticles_, OrbitalSetSize);
34+
auto myclone = std::make_unique<ConstantSPOSet>(SPOSet::my_name_, numparticles_, SPOSet::OrbitalSetSize);
3235
myclone->setRefVals(ref_psi_);
3336
myclone->setRefEGrads(ref_egrad_);
3437
myclone->setRefELapls(ref_elapl_);
3538
return myclone;
3639
};
3740

38-
std::string ConstantSPOSet::getClassName() const { return "ConstantSPOSet"; };
41+
template<typename T>
42+
std::string ConstantSPOSet<T>::getClassName() const { return "ConstantSPOSet"; };
3943

40-
void ConstantSPOSet::checkOutVariables(const opt_variables_type& active)
44+
template<typename T>
45+
void ConstantSPOSet<T>::checkOutVariables(const opt_variables_type& active)
4146
{
4247
APP_ABORT("ConstantSPOSet should not call checkOutVariables");
4348
};
4449

45-
void ConstantSPOSet::setOrbitalSetSize(int norbs) { APP_ABORT("ConstantSPOSet should not call setOrbitalSetSize()"); }
50+
template<typename T>
51+
void ConstantSPOSet<T>::setOrbitalSetSize(int norbs) { APP_ABORT("ConstantSPOSet should not call setOrbitalSetSize()"); }
4652

47-
void ConstantSPOSet::setRefVals(const ValueMatrix& vals)
53+
template<typename T>
54+
void ConstantSPOSet<T>::setRefVals(const ValueMatrix& vals)
4855
{
49-
assert(vals.cols() == OrbitalSetSize);
56+
assert(vals.cols() == SPOSet::OrbitalSetSize);
5057
assert(vals.rows() == numparticles_);
5158
ref_psi_ = vals;
5259
};
53-
void ConstantSPOSet::setRefEGrads(const GradMatrix& grads)
60+
61+
template<typename T>
62+
void ConstantSPOSet<T>::setRefEGrads(const GradMatrix& grads)
5463
{
55-
assert(grads.cols() == OrbitalSetSize);
64+
assert(grads.cols() == SPOSet::OrbitalSetSize);
5665
assert(grads.rows() == numparticles_);
5766
ref_egrad_ = grads;
5867
};
59-
void ConstantSPOSet::setRefELapls(const ValueMatrix& lapls)
68+
69+
template<typename T>
70+
void ConstantSPOSet<T>::setRefELapls(const ValueMatrix& lapls)
6071
{
61-
assert(lapls.cols() == OrbitalSetSize);
72+
assert(lapls.cols() == SPOSet::OrbitalSetSize);
6273
assert(lapls.rows() == numparticles_);
6374
ref_elapl_ = lapls;
6475
};
6576

66-
void ConstantSPOSet::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
77+
template<typename T>
78+
void ConstantSPOSet<T>::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
6779
{
6880
const auto* vp = dynamic_cast<const VirtualParticleSet*>(&P);
6981
int ptcl = vp ? vp->refPtcl : iat;
70-
assert(psi.size() == OrbitalSetSize);
71-
for (int iorb = 0; iorb < OrbitalSetSize; iorb++)
82+
assert(psi.size() == SPOSet::OrbitalSetSize);
83+
for (int iorb = 0; iorb < SPOSet::OrbitalSetSize; iorb++)
7284
psi[iorb] = ref_psi_(ptcl, iorb);
7385
};
7486

75-
void ConstantSPOSet::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
87+
template<typename T>
88+
void ConstantSPOSet<T>::evaluateVGL(const ParticleSet& P,
89+
int iat,
90+
ValueVector& psi,
91+
GradVector& dpsi,
92+
ValueVector& d2psi)
7693
{
77-
for (int iorb = 0; iorb < OrbitalSetSize; iorb++)
94+
for (int iorb = 0; iorb < SPOSet::OrbitalSetSize; iorb++)
7895
{
7996
psi[iorb] = ref_psi_(iat, iorb);
8097
dpsi[iorb] = ref_egrad_(iat, iorb);
8198
d2psi[iorb] = ref_elapl_(iat, iorb);
8299
}
83100
};
84101

85-
void ConstantSPOSet::evaluate_notranspose(const ParticleSet& P,
86-
int first,
87-
int last,
88-
ValueMatrix& logdet,
89-
GradMatrix& dlogdet,
90-
ValueMatrix& d2logdet)
102+
template<typename T>
103+
void ConstantSPOSet<T>::evaluate_notranspose(const ParticleSet& P,
104+
int first,
105+
int last,
106+
ValueMatrix& logdet,
107+
GradMatrix& dlogdet,
108+
ValueMatrix& d2logdet)
91109
{
92110
for (int iat = first, i = 0; iat < last; ++iat, ++i)
93111
{
@@ -97,4 +115,13 @@ void ConstantSPOSet::evaluate_notranspose(const ParticleSet& P,
97115
evaluateVGL(P, iat, v, g, l);
98116
}
99117
}
118+
119+
#if !defined(MIXED_PRECISION)
120+
template class ConstantSPOSet<double>;
121+
template class ConstantSPOSet<std::complex<double>>;
122+
#endif
123+
template class ConstantSPOSet<float>;
124+
template class ConstantSPOSet<std::complex<float>>;
125+
126+
100127
} //namespace qmcplusplus

src/QMCWaveFunctions/tests/ConstantSPOSet.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,17 @@ namespace qmcplusplus
2222
* Exists to provide deterministic and known output to objects requiring SPOSet evaluations.
2323
*
2424
*/
25-
class ConstantSPOSet : public SPOSet
25+
template<typename T>
26+
class ConstantSPOSet : public SPOSetT<T>
2627
{
2728
public:
29+
using SPOSet = SPOSetT<T>;
30+
31+
using ValueVector = typename SPOSet::ValueVector;
32+
using ValueMatrix = typename SPOSet::ValueMatrix;
33+
using GradVector = typename SPOSet::GradVector;
34+
using GradMatrix = typename SPOSet::GradMatrix;
35+
2836
ConstantSPOSet(const std::string& my_name) = delete;
2937

3038
//Constructor needs number of particles and number of orbitals. This is the minimum

src/QMCWaveFunctions/tests/test_CompositeSPOSet.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TEST_CASE("CompositeSPO::diamond_1x1x1", "[wavefunction")
3535
auto& pset = *particle_pool.getParticleSet("e");
3636
auto& twf = *wavefunction_pool.getWaveFunction("wavefunction");
3737

38-
CompositeSPOSet comp_sposet("one_composite_set");
38+
CompositeSPOSet<SPOSet::ValueType> comp_sposet("one_composite_set");
3939

4040
std::vector<std::string> sposets{"spo_ud", "spo_dm"};
4141
for (auto sposet_str : sposets)

0 commit comments

Comments
 (0)