Skip to content

Commit

Permalink
Improve point sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
lululxvi committed Jul 30, 2022
1 parent 26e2a98 commit 9e9cda6
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 28 deletions.
8 changes: 4 additions & 4 deletions deepxde/data/func_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):

@run_if_any_none("train_x", "train_y")
def train_next_batch(self, batch_size=None):
if self.dist_train == "log uniform":
if self.dist_train == "uniform":
self.train_x = self.geom.uniform_points(self.num_train, False)
elif self.dist_train == "log uniform":
self.train_x = self.geom.log_uniform_points(self.num_train, False)
elif self.dist_train == "random":
else:
self.train_x = self.geom.random_points(
self.num_train, random=self.dist_train
)
else:
self.train_x = self.geom.uniform_points(self.num_train, False)
if self.anchors is not None:
self.train_x = np.vstack((self.anchors, self.train_x))
self.train_y = self.func(self.train_x)
Expand Down
10 changes: 5 additions & 5 deletions deepxde/geometry/geometry_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def on_boundary(self, x):
return np.isclose(np.linalg.norm(x - self.center, axis=-1), self.radius)

def distance2boundary_unitdirn(self, x, dirn):
"""https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection"""
# https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
xc = x - self.center
ad = np.dot(xc, dirn)
return -ad + (ad ** 2 - np.sum(xc * xc, axis=-1) + self._r2) ** 0.5
Expand All @@ -136,24 +136,24 @@ def boundary_normal(self, x):
return _n

def random_points(self, n, random="pseudo"):
"""https://math.stackexchange.com/questions/87230/picking-random-points-in-the-volume-of-sphere-with-uniform-probability"""
# https://math.stackexchange.com/questions/87230/picking-random-points-in-the-volume-of-sphere-with-uniform-probability
if random == "pseudo":
U = np.random.rand(n, 1)
X = np.random.normal(size=(n, self.dim))
else:
rng = sample(n, self.dim + 1, random)
U, X = rng[:, 0:1], rng[:, 1:]
U, X = rng[:, 0:1], rng[:, 1:] # Error if X = [0, 0, ...]
X = stats.norm.ppf(X)
X = preprocessing.normalize(X)
X = U ** (1 / self.dim) * X
return self.radius * X + self.center

def random_boundary_points(self, n, random="pseudo"):
"""http://mathworld.wolfram.com/HyperspherePointPicking.html"""
# http://mathworld.wolfram.com/HyperspherePointPicking.html
if random == "pseudo":
X = np.random.normal(size=(n, self.dim)).astype(config.real(np))
else:
U = sample(n, self.dim, random)
U = sample(n, self.dim, random) # Error for [0, 0, ...] or [0.5, 0.5, ...]
X = stats.norm.ppf(U)
X = preprocessing.normalize(X)
return self.radius * X + self.center
Expand Down
39 changes: 22 additions & 17 deletions deepxde/geometry/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
__all__ = ["sample"]

from distutils.version import LooseVersion

import numpy as np
import skopt

Expand Down Expand Up @@ -35,24 +33,31 @@ def pseudorandom(n_samples, dimension):


def quasirandom(n_samples, dimension, sampler):
# Certain points should be removed:
# - Boundary points such as [..., 0, ...]
# - Special points [0, 0, 0, ...] and [0.5, 0.5, 0.5, ...], which cause error in
# Hypersphere.random_points() and Hypersphere.random_boundary_points()
skip = 0
if sampler == "LHS":
sampler = skopt.sampler.Lhs(
lhs_type="centered", criterion="maximin", iterations=1000
)
sampler = skopt.sampler.Lhs()
elif sampler == "Halton":
sampler = skopt.sampler.Halton(min_skip=-1, max_skip=-1)
# 1st point: [0, 0, ...]
sampler = skopt.sampler.Halton(min_skip=1, max_skip=1)
elif sampler == "Hammersley":
sampler = skopt.sampler.Hammersly(min_skip=-1, max_skip=-1)
# 1st point: [0, 0, ...]
if dimension == 1:
sampler = skopt.sampler.Hammersly(min_skip=1, max_skip=1)
else:
sampler = skopt.sampler.Hammersly()
skip = 1
elif sampler == "Sobol":
# Remove the first point [0, 0, ...] and the second point [0.5, 0.5, ...], which
# are too special and may cause some error.
if LooseVersion(skopt.__version__) < LooseVersion("0.9"):
sampler = skopt.sampler.Sobol(min_skip=2, max_skip=2, randomize=False)
# 1st point: [0, 0, ...], 2nd point: [0.5, 0.5, ...]
sampler = skopt.sampler.Sobol(randomize=False)
if dimension < 3:
skip = 1
else:
sampler = skopt.sampler.Sobol(skip=0, randomize=False)
space = [(0.0, 1.0)] * dimension
return np.asarray(
sampler.generate(space, n_samples + 2)[2:], dtype=config.real(np)
)
skip = 2
space = [(0.0, 1.0)] * dimension
return np.asarray(sampler.generate(space, n_samples), dtype=config.real(np))
return np.asarray(
sampler.generate(space, n_samples + skip)[skip:], dtype=config.real(np)
)
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
matplotlib
numpy
scikit-learn
scikit-optimize
scikit-optimize>=0.9.0
scipy
docutils<0.18 # https://github.com/readthedocs/readthedocs.org/issues/8616

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
matplotlib
numpy
scikit-learn
scikit-optimize
scikit-optimize>=0.9.0
scipy

0 comments on commit 9e9cda6

Please sign in to comment.