Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add palace::GridFunction to unify mfem::ParGridFunction and mfem::ParComplexGridFunction #204

Merged
merged 3 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion palace/drivers/basesolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ void BaseSolver::PostprocessProbes(const PostOperator &postop, const std::string
{
return;
}
const bool has_imaginary = postop.HasImaginary();
const bool has_imaginary = postop.HasImag();
for (int f = 0; f < 2; f++)
{
// Probe data is ordered as [Fx1, Fy1, Fz1, Fx2, Fy2, Fz2, ...].
Expand Down
1 change: 1 addition & 0 deletions palace/fem/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ target_sources(${LIB_TARGET_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/coefficient.cpp
${CMAKE_CURRENT_SOURCE_DIR}/errorindicator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fespace.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gridfunction.cpp
${CMAKE_CURRENT_SOURCE_DIR}/integrator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/interpolator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/lumpedelement.cpp
Expand Down
18 changes: 12 additions & 6 deletions palace/fem/bilinearform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,20 @@ std::unique_ptr<ceed::Operator> DiscreteLinearOperator::PartialAssemble() const
Vector test_multiplicity(test_fespace.GetVSize());
test_multiplicity = 0.0;
auto *h_mult = test_multiplicity.HostReadWrite();
mfem::Array<int> dofs;
for (int i = 0; i < test_fespace.GetMesh().GetNE(); i++)
PalacePragmaOmp(parallel)
{
test_fespace.Get().GetElementVDofs(i, dofs);
for (int j = 0; j < dofs.Size(); j++)
mfem::Array<int> dofs;
mfem::DofTransformation dof_trans;
PalacePragmaOmp(for schedule(static))
for (int i = 0; i < test_fespace.GetMesh().GetNE(); i++)
{
const int k = dofs[j];
h_mult[(k >= 0) ? k : -1 - k] += 1.0;
test_fespace.Get().GetElementVDofs(i, dofs, dof_trans);
for (int j = 0; j < dofs.Size(); j++)
{
const int k = dofs[j];
PalacePragmaOmp(atomic update)
h_mult[(k >= 0) ? k : -1 - k] += 1.0;
}
}
}
test_multiplicity.UseDevice(true);
Expand Down
63 changes: 24 additions & 39 deletions palace/fem/coefficient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <utility>
#include <vector>
#include <mfem.hpp>
#include "fem/gridfunction.hpp"
#include "models/materialoperator.hpp"

namespace palace
Expand Down Expand Up @@ -350,19 +351,19 @@ enum class EnergyDensityType
// Returns the local energy density evaluated as 1/2 Dᴴ E or 1/2 Bᴴ H for real-valued
// material coefficients. For internal boundary elements, the solution is taken on the side
// of the element with the larger-valued material property (permittivity or permeability).
template <EnergyDensityType Type, typename GridFunctionType>
template <EnergyDensityType Type>
class EnergyDensityCoefficient : public mfem::Coefficient, public BdrGridFunctionCoefficient
{
private:
const GridFunctionType &U;
const GridFunction &U;
const MaterialOperator &mat_op;
mfem::Vector V;

double GetLocalEnergyDensity(mfem::ElementTransformation &T,
const mfem::IntegrationPoint &ip, int attr);

public:
EnergyDensityCoefficient(const GridFunctionType &gf, const MaterialOperator &mat_op)
EnergyDensityCoefficient(const GridFunction &gf, const MaterialOperator &mat_op)
: mfem::Coefficient(), BdrGridFunctionCoefficient(*gf.ParFESpace()->GetParMesh()),
U(gf), mat_op(mat_op), V(mat_op.SpaceDimension())
{
Expand Down Expand Up @@ -399,49 +400,33 @@ class EnergyDensityCoefficient : public mfem::Coefficient, public BdrGridFunctio
};

template <>
inline double
EnergyDensityCoefficient<EnergyDensityType::ELECTRIC, mfem::ParComplexGridFunction>::
GetLocalEnergyDensity(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip,
int attr)
inline double EnergyDensityCoefficient<EnergyDensityType::ELECTRIC>::GetLocalEnergyDensity(
mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip, int attr)
{
// Only the real part of the permittivity contributes to the energy (imaginary part
// cancels out in the inner product due to symmetry).
U.real().GetVectorValue(T, ip, V);
double res = mat_op.GetPermittivityReal(attr).InnerProduct(V, V);
U.imag().GetVectorValue(T, ip, V);
res += mat_op.GetPermittivityReal(attr).InnerProduct(V, V);
return 0.5 * res;
}

template <>
inline double EnergyDensityCoefficient<EnergyDensityType::ELECTRIC, mfem::ParGridFunction>::
GetLocalEnergyDensity(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip,
int attr)
{
U.GetVectorValue(T, ip, V);
return 0.5 * mat_op.GetPermittivityReal(attr).InnerProduct(V, V);
}

template <>
inline double
EnergyDensityCoefficient<EnergyDensityType::MAGNETIC, mfem::ParComplexGridFunction>::
GetLocalEnergyDensity(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip,
int attr)
{
U.real().GetVectorValue(T, ip, V);
double res = mat_op.GetInvPermeability(attr).InnerProduct(V, V);
U.imag().GetVectorValue(T, ip, V);
res += mat_op.GetInvPermeability(attr).InnerProduct(V, V);
return 0.5 * res;
U.Real().GetVectorValue(T, ip, V);
double dot = mat_op.GetPermittivityReal(attr).InnerProduct(V, V);
if (U.HasImag())
{
U.Imag().GetVectorValue(T, ip, V);
dot += mat_op.GetPermittivityReal(attr).InnerProduct(V, V);
}
return 0.5 * dot;
}

template <>
inline double EnergyDensityCoefficient<EnergyDensityType::MAGNETIC, mfem::ParGridFunction>::
GetLocalEnergyDensity(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip,
int attr)
inline double EnergyDensityCoefficient<EnergyDensityType::MAGNETIC>::GetLocalEnergyDensity(
mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip, int attr)
{
U.GetVectorValue(T, ip, V);
return 0.5 * mat_op.GetInvPermeability(attr).InnerProduct(V, V);
U.Real().GetVectorValue(T, ip, V);
double dot = mat_op.GetInvPermeability(attr).InnerProduct(V, V);
if (U.HasImag())
{
U.Imag().GetVectorValue(T, ip, V);
dot += mat_op.GetInvPermeability(attr).InnerProduct(V, V);
}
return 0.5 * dot;
}

// Returns the local field evaluated on a boundary element. For internal boundary elements,
Expand Down
60 changes: 60 additions & 0 deletions palace/fem/gridfunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#include "gridfunction.hpp"

#include "fem/fespace.hpp"

namespace palace
{

GridFunction::GridFunction(mfem::ParFiniteElementSpace &fespace, bool complex)
: gfr(&fespace)
{
if (complex)
{
gfi.SetSpace(&fespace);
}
}

GridFunction::GridFunction(FiniteElementSpace &fespace, bool complex)
: GridFunction(fespace.Get(), complex)
{
}

GridFunction &GridFunction::operator=(std::complex<double> s)
{
Comment on lines +25 to +26
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MFEM_ASSERT(s.imag() == 0.0 || !HasImag(), "Cannot assign complex scalar to a non-complex grid function!");

Assigning an actually complex into a real grid function is almost certainly an error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 5f0543e

Real() = s.real();
if (HasImag())
{
Imag() = s.imag();
}
else
{
MFEM_ASSERT(
s.imag() == 0.0,
"Cannot assign complex scalar to a non-complex-valued GridFunction object!");
}
return *this;
}

GridFunction &GridFunction::operator*=(double s)
{
Real() *= s;
if (HasImag())
{
Imag() *= s;
}
return *this;
}

void GridFunction::Update()
{
Real().Update();
if (HasImag())
{
Imag().Update();
}
}

} // namespace palace
73 changes: 73 additions & 0 deletions palace/fem/gridfunction.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#ifndef PALACE_FEM_GRIDFUNCTION_HPP
#define PALACE_FEM_GRIDFUNCTION_HPP

#include <mfem.hpp>

namespace palace
{

class FiniteElementSpace;

//
// A real- or complex-valued grid function represented as two real grid functions, one for
// each component. This unifies mfem::ParGridFunction and mfem::ParComplexGridFunction, and
// replaces the latter due to some issues observed with memory aliasing on GPUs.
//
class GridFunction
{
private:
mfem::ParGridFunction gfr, gfi;
Comment on lines +19 to +22
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative here is something like

template <bool Complex = false>
class GridFunction
{
private:
  std::array<mfem::ParGridFunction, Complex ? 2 : 1> gfs;

public:
  GridFunction(mfem::ParFiniteElementSpace &fespace);
  GridFunction(FiniteElementSpace &fespace);

  // Get access to the real and imaginary grid function parts.
  const mfem::ParGridFunction &Real() const { return gfs[0]; }
  mfem::ParGridFunction &Real() { return gf[0]; }

  template <bool Enable = Complex, typename =std::enable_if_t<Complex>>
  const mfem::ParGridFunction &Imag() const
  {
    return gfs[1];
  }

  template <bool Enable = Complex, typename =std::enable_if_t<Complex>>
  mfem::ParGridFunction &Imag()
  {
    return gfs[1];
  }

  // Check if the grid function is suited for storing complex-valued fields.
  bool HasImag() const { return Complex; }

  // Get access to the underlying finite element space (match MFEM interface).
  mfem::FiniteElementSpace *FESpace() { return gfr.FESpace(); }
  const mfem::FiniteElementSpace *FESpace() const { return gfr.FESpace(); }
  mfem::ParFiniteElementSpace *ParFESpace() { return gfr.ParFESpace(); }
  const mfem::ParFiniteElementSpace *ParFESpace() const { return gfr.ParFESpace(); }

  // Set all entries equal to s.

  template <bool Enable = Complex, typename =std::enable_if_t<Enable>>
  GridFunction &operator=(std::complex<double> s);

  GridFunction &operator=(double s);

  // Scale all entries by s.
  GridFunction &operator*=(double s);

  // Transform for space update (for example on mesh change).
  void Update();

  // Get the associated MPI communicator.
  MPI_Comm GetComm() const { return ParFESpace()->GetComm(); }
};

where the enable_if bits remove the imaginary methods (and operators) in the case of non-complex. This would mirror the existing the split between the two grid function types, with all the corresponding dispatch.

I think the runtime flag here is a better approach though, as the template is only going to eliminate if branches at a very high level, so the template isn't really worth it imo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered the template version but came to the same conclusion you have, that the runtime flag is just fine in this case from a performance and from an interface point of view.


public:
GridFunction(mfem::ParFiniteElementSpace &fespace, bool complex = false);
GridFunction(FiniteElementSpace &fespace, bool complex = false);

// Get access to the real and imaginary grid function parts.
const mfem::ParGridFunction &Real() const { return gfr; }
mfem::ParGridFunction &Real() { return gfr; }
const mfem::ParGridFunction &Imag() const
{
MFEM_ASSERT(HasImag(),
"Invalid access of imaginary part of a real-valued GridFunction object!");
return gfi;
}
mfem::ParGridFunction &Imag()
{
MFEM_ASSERT(HasImag(),
"Invalid access of imaginary part of a real-valued GridFunction object!");
return gfi;
}

// Check if the grid function is suited for storing complex-valued fields.
bool HasImag() const { return (gfi.ParFESpace() != nullptr); }

// Get access to the underlying finite element space (match MFEM interface).
mfem::FiniteElementSpace *FESpace() { return gfr.FESpace(); }
const mfem::FiniteElementSpace *FESpace() const { return gfr.FESpace(); }
mfem::ParFiniteElementSpace *ParFESpace() { return gfr.ParFESpace(); }
const mfem::ParFiniteElementSpace *ParFESpace() const { return gfr.ParFESpace(); }

// Set all entries equal to s.
GridFunction &operator=(std::complex<double> s);
GridFunction &operator=(double s)
{
*this = std::complex<double>(s, 0.0);
return *this;
}

// Scale all entries by s.
GridFunction &operator*=(double s);

// Transform for space update (for example on mesh change).
void Update();

// Get the associated MPI communicator.
MPI_Comm GetComm() const { return ParFESpace()->GetComm(); }
};

} // namespace palace

#endif // PALACE_FEM_GRIDFUNCTION_HPP
10 changes: 5 additions & 5 deletions palace/fem/interpolator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "interpolator.hpp"

#include <algorithm>
#include "fem/gridfunction.hpp"
#include "utils/communication.hpp"
#include "utils/iodata.hpp"

Expand Down Expand Up @@ -88,13 +89,12 @@ std::vector<double> InterpolationOperator::ProbeField(const mfem::ParGridFunctio
#endif
}

std::vector<std::complex<double>>
InterpolationOperator::ProbeField(const mfem::ParComplexGridFunction &U, bool has_imaginary)
std::vector<std::complex<double>> InterpolationOperator::ProbeField(const GridFunction &U)
{
std::vector<double> vr = ProbeField(U.real());
if (has_imaginary)
std::vector<double> vr = ProbeField(U.Real());
if (U.HasImag())
{
std::vector<double> vi = ProbeField(U.imag());
std::vector<double> vi = ProbeField(U.Imag());
std::vector<std::complex<double>> vals(vr.size());
std::transform(vr.begin(), vr.end(), vi.begin(), vals.begin(),
[](double xr, double xi) { return std::complex<double>(xr, xi); });
Expand Down
8 changes: 4 additions & 4 deletions palace/fem/interpolator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
namespace palace
{

class GridFunction;
class IoData;

//
Expand All @@ -24,15 +25,14 @@ class InterpolationOperator
#endif
std::vector<int> op_idx;

std::vector<double> ProbeField(const mfem::ParGridFunction &U);

public:
InterpolationOperator(const IoData &iodata, mfem::ParMesh &mesh);

const auto &GetProbes() const { return op_idx; }

std::vector<double> ProbeField(const mfem::ParGridFunction &U);

std::vector<std::complex<double>> ProbeField(const mfem::ParComplexGridFunction &U,
bool has_imaginary);
std::vector<std::complex<double>> ProbeField(const GridFunction &U);
};

} // namespace palace
Expand Down
Loading