Skip to content

Commit

Permalink
Separates pad and extract, and allows double prec.
Browse files Browse the repository at this point in the history
Previously padding the model with cells to make spatial finite
difference cleaner, and for the PML, was done inside scalar.py. They are
now separate differentiable modules that are chained together before
calling the propagator. This is more elegant, but it also improves the
accuracy of backpropagation as the effect of the padded area is now
included (an admittedly small effect). The cost of this is that the
imaging condition in the propagator now has to be applied to the whole
domain (i.e., including padding) rather than just the interior -
somewhat increasing memory usage and computation and also requiring
changes to the imaging condition so that it is still correct in the PML
regions. New tests accompany the additional modules.

The other major change is that the CPU propagator is now compiled for
both single and double precision. Double precision is unlikely to be of
interest for seismic propagation, but it is useful for testing the
propagator. PyTorch's gradcheck is now used to verify the propagator
instead of custom tests previously used. Fully accounting for the effect
of the PML in the 3D propagator would be expensive as it would require
storing additional wavefields, so the expensive terms are omitted. As a
result, the 3D gradient is not exactly right, and so the accuracy limit
of gradcheck has to be increased for that case. The gradcheck tests also
fail for the GPU propagator (at least in 2D and 3D) due to them not
being reentrant. I believe this is because functions such as atomicAdd
are not deterministic. When I checked, the difference between repeated
calls to the 2D propagator was O(1e-10), so is in reality insignificant.
Double precision is not enabled for the GPU as atomicAdd of doubles only
works on newer GPUs. This could be enabled by editing setup.py.
  • Loading branch information
ar4 committed Aug 17, 2018
1 parent 10729ec commit fdf95a9
Show file tree
Hide file tree
Showing 18 changed files with 1,618 additions and 1,387 deletions.
1 change: 1 addition & 0 deletions deepwave/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wave propagation modules for PyTorch."""

import deepwave.base
import deepwave.scalar
import deepwave.utils
import deepwave.wavelets
6 changes: 6 additions & 0 deletions deepwave/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Base classes for Deepwave."""

import deepwave.base.model
import deepwave.base.propagator
import deepwave.base.pad
import deepwave.base.extract
153 changes: 153 additions & 0 deletions deepwave/base/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Extract Module to extract model around survey."""
import math
import torch


class Extract(torch.nn.Module):
"""Extract a portion of the model containing the sources and receivers.
Args:
survey_pad: Float or None or list of such of length
2 * dimensionality of model, specifying padding around sources and
receivers in units of dx.
"""

def __init__(self, survey_pad):
super(Extract, self).__init__()
self.survey_pad = survey_pad

def forward(self, model, source_locations, receiver_locations):
"""Perform extraction.
Args:
model: A Model object
source_locations: A Tensor containing source locations in units
of dx
receiver_locations: A Tensor containing receiver locations in
units of dx
Returns:
Extracted model as a Model object
"""
survey_pad = _set_survey_pad(self.survey_pad, model.ndim)
survey_extents = _get_survey_extents(
model.shape,
model.dx,
survey_pad,
source_locations,
receiver_locations)
extracted_model = _extract_model(model, survey_extents)
return extracted_model


def _set_survey_pad(survey_pad, ndim):
"""Check survey_pad, and convert to a list if it is a scalar."""
# Expand to list
if isinstance(survey_pad, (float, type(None))):
survey_pad = [survey_pad] * 2 * ndim

# Check is non-negative or None
if not all((pad is None) or (pad >= 0) for pad in survey_pad):
raise RuntimeError('survey_pad must be non-negative or None, '
'but got {}'.format(survey_pad))

# Check has correct size
if len(survey_pad) != 2 * ndim:
raise RuntimeError('survey_pad must have length 2 * dims in model, '
'but got {}'.format(len(survey_pad)))

return survey_pad


def _get_survey_extents(model_shape, dx, survey_pad, source_locations,
receiver_locations):
"""Calculate the extents of the model to use for the survey.
Args:
model_shape: A tuple containing the shape of the full model
dx: A Tensor containing the cell spacing in each dimension
survey_pad: A list with two entries for
each dimension, specifying the padding to add
around the sources and receivers included in all of the
shots being propagated. If None, the padding continues
to the edge of the model
source_locations: A Tensor containing source locations
receiver_locations: A Tensor containing receiver locations
Returns:
A list of slices of the same length as the model shape,
specifying the extents of the model that will be
used for wave propagation
"""
ndims = len(dx)
extents = []
for dim in range(ndims):
left_pad = survey_pad[dim * 2]
left_extent = \
_get_survey_extents_one_side(left_pad, 'left',
source_locations[..., dim],
receiver_locations[..., dim],
model_shape[dim],
dx[dim].item())

right_pad = survey_pad[dim * 2 + 1]
right_extent = \
_get_survey_extents_one_side(right_pad, 'right',
source_locations[..., dim],
receiver_locations[..., dim],
model_shape[dim],
dx[dim].item())

extents.append(slice(left_extent, right_extent))

return extents


def _get_survey_extents_one_side(pad, side, source_locations,
receiver_locations, shape, dx):
"""Get the survey extent for the left or right side of one dimension.
Args:
pad: Positive float specifying padding for the side
side: 'left' or 'right'
source/receiver_locations: Tensor with coordinates for the current
dimension
shape: Int specifying length of full model in current dimension
dx: Float specifying cell spacing in current dimension
Returns:
Min/max index as int or None
"""
if pad is None:
return None
if side == 'left':
pad = -pad
op = torch.min
nearest = math.floor
else:
pad = +pad
op = torch.max
nearest = math.ceil
extreme_source = op(source_locations + pad)
extreme_receiver = op(receiver_locations + pad)
extreme_cell = nearest(op(extreme_source, extreme_receiver).item() / dx)
if side == 'right':
extreme_cell += 1
if (extreme_cell <= 0) or (extreme_cell >= shape):
extreme_cell = None
return extreme_cell


def _extract_model(model, extents):
"""Extract the specified portion of the model.
Args:
model: A Model object
extents: A list of slices specifying the portion of the model to
extract
Returns:
A Model containing the desired portion of the model
"""

return model[extents]
Loading

0 comments on commit fdf95a9

Please sign in to comment.