"""
specter.psf.psf
===============
Base class for 2D PSFs
Provides the PSF base class, which defines the interface for other code
using PSFs. Subclasses implement specific models of the PSF and
override/extend the __init__ and _xypix(ispec, wavelength) methods,
while allowing interchangeable use of different PSF models through
the interface defined in this base class.
Stephen Bailey, Fall 2012
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import numbers
import numpy as np
from numpy.polynomial.legendre import Legendre, legval, legfit
import scipy.optimize
import scipy.sparse
from specter.util import legval_numba
from specter.util import gausspix, TraceSet, CacheDict
from astropy.io import fits
class PSF(object):
    """
    Base class for 2D PSFs

    Subclasses need to extend __init__ to load format-specific items
    from the input fits file and implement _xypix(ispec, wavelength)
    to return xslice, yslice, pixels[y,x] for the PSF evaluated at
    spectrum ispec at the given wavelength.  All interactions with PSF
    classes should be via the methods defined here, allowing
    interchangeable use of different PSF models.
    """
    def __init__(self, filename):
        """
        Load PSF parameters from a file

        Reads from the FITS file `filename`:
            primary header : NPIX_X, NPIX_Y, NSPEC and optional PSFERR
            XCOEFF, YCOEFF : Legendre coefficients of the x and y traces
                             vs. wavelength, with the wavelength domain
                             taken from each HDU's WAVEMIN/WAVEMAX keywords

        and fills:
            self.npix_x   #- number of columns in the target image
            self.npix_y   #- number of rows in the target image
            self.nspec    #- number of spectra (fibers)
            self.psferr   #- PSF model error (PSFERR keyword, default 0.01)

        self.nwave is not set by this base class.
        Subclasses implement _xypix(ispec, wavelength) to provide the
        projection of this PSF into pixels.
        """
        with fits.open(filename) as fx:
            #- Load basic dimensions from the primary header
            hdr = fx[0].header
            self.npix_x = hdr['NPIX_X']
            self.npix_y = hdr['NPIX_Y']
            self.nspec = hdr['NSPEC']

            #- PSF model error
            self.psferr = hdr.get('PSFERR', 0.01)

            #- Load x, y legendre coefficient tracesets
            xhdu = fx['XCOEFF']
            self._x = TraceSet(xhdu.data,
                domain=(xhdu.header['WAVEMIN'], xhdu.header['WAVEMAX']))
            yhdu = fx['YCOEFF']
            self._y = TraceSet(yhdu.data,
                domain=(yhdu.header['WAVEMIN'], yhdu.header['WAVEMAX']))

        #- Create inverse y -> wavelength mapping
        self._w = self._y.invert()

        #- Cache min/max wavelength per fiber at pixel edges
        #- (pixel coordinates are pixel centers, so edges are at -0.5
        #- and npix_y-0.5)
        self._wmin_spec = self.wavelength(None, -0.5)
        self._wmax_spec = self.wavelength(None, self.npix_y - 0.5)
        self._wmin = np.min(self._wmin_spec)
        self._wmin_all = np.max(self._wmin_spec)
        self._wmax = np.max(self._wmax_spec)
        self._wmax_all = np.min(self._wmax_spec)

        #- Filled only if needed (see xsigma / ysigma)
        self._xsigma = None
        self._ysigma = None

    #- Utility function to fit spot sigma vs. wavelength
    def _fit_spot_sigma(self, ispec, axis=0, npoly=5):
        """
        Fit the cross-sectional Gaussian sigma of PSF spots vs. wavelength.

        Return callable Legendre object.

        Arguments:
            ispec : spectrum number
            axis  : 0 or 'x' for cross dispersion sigma;
                    1 or 'y' or 'w' for wavelength dispersion.
                    Any integer type (e.g. numpy integers) is accepted.
            npoly : degree of Legendre poly to fit to sigma vs. wavelength

        Returns:
            Legendre series `fit` such that fit(w) returns sigma at
            wavelengths w

        Raises:
            ValueError if axis is not one of the values above
        """
        if not isinstance(axis, numbers.Integral):
            if axis in ('x', 'X'):
                axis = 0
            elif axis in ('y', 'Y', 'w', 'W'):
                axis = 1
            else:
                raise ValueError("Unknown axis type {}".format(axis))

        if axis not in (0, 1):
            raise ValueError("axis must be 0, 'x', 1, 'y', or 'w'")

        #- Sample spots at 20 rows, staying 10 pixels from the CCD edges
        yy = np.linspace(10, self.npix_y - 10, 20)
        ww = self.wavelength(ispec, y=yy)
        xsig = list()   #- sigma vs. wavelength array to fill
        for w in ww:
            #- Collapse the 2D spot pixels[iy,ix] along `axis` to a 1D profile
            xspot = self.pix(ispec, w).sum(axis=axis)
            xspot /= np.sum(xspot)  #- normalize for edge cases
            xx = np.arange(len(xspot))
            mean, sigma = scipy.optimize.curve_fit(gausspix, xx, xspot)[0]
            xsig.append(sigma)

        #- Fit Legendre polynomial and return the fitted series
        fit = Legendre.fit(ww, xsig, npoly, domain=(self._wmin, self._wmax))
        return fit

    #-------------------------------------------------------------------------
    #- Cross dispersion width for row-by-row extractions
    def xsigma(self, ispec, wavelength):
        """
        Return Gaussian sigma of PSF spot in cross-dispersion direction
        in CCD pixel units.

        ispec : spectrum index
        wavelength : scalar or vector wavelength(s) to evaluate spot sigmas

        The first time this is called for a spectrum, the PSF is sampled
        at 20 wavelengths and the variation is fit with a 5th order
        Legendre polynomial, and the fit is cached.
        The actual value (and subsequent calls) use these cached
        Legendre fits to interpolate the sigma value.  If this is not
        fast enough and/or accurate enough, PSF subtypes may override
        this function to provide a more accurate xsigma measurement.
        """
        #- First call for any spectrum: setup list to cache fits
        if self._xsigma is None:
            self._xsigma = [None, ] * self.nspec

        #- First call for this spectrum: fit sigma vs. wavelength & cache
        if self._xsigma[ispec] is None:
            self._xsigma[ispec] = self._fit_spot_sigma(ispec, axis=0, npoly=5)

        #- Use cached Legendre fit to interpolate xsigma at wavelength(s)
        return self._xsigma[ispec](wavelength)

    #-------------------------------------------------------------------------
    #- Wavelength-dispersion width in pixels
    def ysigma(self, ispec, wavelength):
        """
        Return Gaussian sigma of PSF spot in wavelength-dispersion direction
        in units of pixels.

        Also see wdisp(...) which returns sigmas in units of Angstroms.

        ispec : spectrum index
        wavelength : scalar or vector wavelength(s) to evaluate spot sigmas

        As for xsigma(...), the first call for a spectrum fits a 5th order
        Legendre polynomial to sigma vs. wavelength and caches the fit.
        """
        #- First call for any spectrum: setup list to cache fits
        if self._ysigma is None:
            self._ysigma = [None, ] * self.nspec

        #- First call for this spectrum: fit sigma vs. wavelength & cache
        if self._ysigma[ispec] is None:
            self._ysigma[ispec] = self._fit_spot_sigma(ispec, axis=1, npoly=5)

        #- Use cached Legendre fit to interpolate ysigma at wavelength(s)
        return self._ysigma[ispec](wavelength)

    #-------------------------------------------------------------------------
    #- Wavelength-dispersion width in Angstroms
    def wdisp(self, ispec, wavelength):
        """
        Return Gaussian sigma of PSF spot in wavelength-dispersion direction
        in units of Angstroms.

        Also see ysigma(...) which returns sigmas in units of pixels.

        ispec : spectrum index
        wavelength : scalar or vector wavelength(s) to evaluate spot sigmas

        Computed as ysigma(...) times the local Angstroms-per-pixel.
        """
        sigma_pix = self.ysigma(ispec, wavelength)
        return self.angstroms_per_pixel(ispec, wavelength) * sigma_pix

    #-------------------------------------------------------------------------
    #- Evaluate the PSF into pixels
    def pix(self, ispec, wavelength):
        """
        Evaluate PSF for spectrum[ispec] at given wavelength

        returns 2D array pixels[iy,ix]

        also see xypix(ispec, wavelength)
        """
        return self.xypix(ispec, wavelength)[2]

    def _xypix(self, ispec, wavelength, ispec_cache=None, iwave_cache=None):
        """
        Subclasses of PSF should implement this to return
        xslice, yslice, pixels[iy,ix] for their particular
        models.  Don't worry about edge effects -- PSF.xypix
        will take care of that.
        """
        raise NotImplementedError

    def xypix(self, ispec, wavelength, xmin=0, xmax=None, ymin=0, ymax=None,
              ispec_cache=None, iwave_cache=None):
        """
        Evaluate PSF for spectrum[ispec] at given wavelength

        returns xslice, yslice, pixels[iy,ix] such that
        image[yslice,xslice] += photons*pixels adds the contribution from
        spectrum ispec at that wavelength.

        if xmin or ymin are set, the slices are relative to those
        minima (useful for simulating subimages).  The spot is clipped
        to [xmin,xmax) x [ymin,ymax) on every side, so pixels.shape always
        matches the slices.  If the spot does not overlap that region, or
        the wavelength is outside this spectrum's range, empty slices and a
        (0,0) pixel array are returned.

        Optional inputs:
            ispec_cache = an index into the spectrum number that starts
                          again at 0 for each patch
            iwave_cache = an index into the wavelength number that starts
                          again at 0 for each patch
        """
        if xmax is None:
            xmax = self.npix_x
        if ymax is None:
            ymax = self.npix_y

        if wavelength < self._wmin_spec[ispec]:
            return slice(0, 0), slice(0, 0), np.zeros((0, 0))
        elif wavelength > self._wmax_spec[ispec]:
            return slice(0, 0), slice(ymax, ymax), np.zeros((0, 0))

        #- Lazily create the cache of unclipped spots; done explicitly so
        #- that AttributeErrors raised inside _xypix are not masked
        cache = getattr(self, '_cache', None)
        if cache is None:
            cache = self._cache = CacheDict(2500)

        key = (ispec, wavelength)
        if key in cache:
            xx, yy, ccdpix = cache[key]
        else:
            xx, yy, ccdpix = self._xypix(ispec, wavelength,
                ispec_cache=ispec_cache, iwave_cache=iwave_cache)
            cache[key] = (xx, yy, ccdpix)

        xlo, xhi = xx.start, xx.stop
        ylo, yhi = yy.start, yy.stop

        #- Check if completely off the edge in any direction
        #- (slice stops are exclusive, hence <= for the low edges)
        if ylo >= ymax:
            return slice(0, 0), slice(ymax, ymax), np.zeros((0, 0))
        elif yhi <= ymin:
            return slice(0, 0), slice(ymin, ymin), np.zeros((0, 0))
        elif xlo >= xmax:
            return slice(xmax, xmax), slice(0, 0), np.zeros((0, 0))
        elif xhi <= xmin:
            return slice(xmin, xmin), slice(0, 0), np.zeros((0, 0))

        #- Clip partially-off-edge spots; each side is handled
        #- independently since a spot may overhang both sides of a
        #- small subimage
        if xlo < xmin:
            ccdpix = ccdpix[:, -(xhi - xmin):]
            xlo = xmin
        if xhi > xmax:
            ccdpix = ccdpix[:, 0:(xmax - xlo)]
            xhi = xmax

        if ylo < ymin:
            ccdpix = ccdpix[-(yhi - ymin):, :]
            ylo = ymin
        if yhi > ymax:
            ccdpix = ccdpix[0:(ymax - ylo), :]
            yhi = ymax

        xx = slice(xlo - xmin, xhi - xmin)
        yy = slice(ylo - ymin, yhi - ymin)

        #- Check if we are off the edge
        if (xx.stop - xx.start == 0) or (yy.stop - yy.start == 0):
            ccdpix = np.zeros((0, 0))

        return xx, yy, ccdpix

    def _x_wave_on_ccd(self, ispec, wavemin, wavemax):
        """
        Return (x, w) trace samples of spectrum ispec restricted to
        wavemin <= w <= wavemax and to rows on the CCD.

        The trace is sampled at every CCD row plus wavemin and wavemax,
        since the row sampling isn't perfect resolution.  Pixel coordinates
        are pixel centers, so "on the CCD" means -0.5 <= y < npix_y-0.5.
        """
        w = np.asarray(self.wavelength(ispec))
        w = w[(wavemin <= w) & (w <= wavemax)]
        w = np.concatenate((w, (wavemin, wavemax)))
        x, y = self.xy(ispec, w)
        onccd = (-0.5 <= y) & (y < self.npix_y - 0.5)
        return x[onccd], w[onccd]

    def xyrange(self, spec_range, wavelengths):
        """
        Return recommended range of pixels which cover these spectra/fluxes:
        (xmin, xmax, ymin, ymax)

        spec_range = indices specmin,specmax (python style indexing),
                     or scalar for single spectrum index
        wavelengths = wavelength range wavemin,wavemax inclusive
                      or sorted array of wavelengths

        The wavelength range is clipped to [self.wmin, self.wmax].

        BUG: will fail if asking for a range where one of the spectra is
        completely off the CCD
        """
        if isinstance(spec_range, numbers.Integral):
            specmin, specmax = spec_range, spec_range + 1
        else:
            specmin, specmax = spec_range

        if isinstance(wavelengths, numbers.Real):
            wavemin = wavemax = wavelengths
        else:
            wavemin, wavemax = wavelengths[0], wavelengths[-1]

        if wavemin < self.wmin:
            wavemin = self.wmin
        if wavemax > self.wmax:
            wavemax = self.wmax

        #- Find the spectra with the smallest/largest y centroids
        ispec_ymin = specmin + np.argmin(self.y(None, wavemin)[specmin:specmax])
        ispec_ymax = specmin + np.argmax(self.y(None, wavemax)[specmin:specmax])
        ymin = self.xypix(ispec_ymin, wavemin)[1].start
        ymax = self.xypix(ispec_ymax, wavemax)[1].stop

        #- Now for wavelength where x = min(x) of the first spectrum,
        #- while staying on CCD and within wavelength range
        x, w = self._x_wave_on_ccd(specmin, wavemin, wavemax)
        if np.min(x) < 0:
            xmin = 0    #- int, so it can be used directly as a pixel bound
        else:
            wxmin = w[np.argmin(x)]  #- wavelength at x minimum
            xmin = self.xypix(specmin, wxmin)[0].start

        #- and wavelength where x = max(x) of the last spectrum
        x, w = self._x_wave_on_ccd(specmax - 1, wavemin, wavemax)
        if np.max(x) > self.npix_x:
            xmax = self.npix_x
        else:
            wxmax = w[np.argmax(x)]
            #- use _xypix not xypix to avoid corner-case rounding when
            #- very near edge of CCD
            xmax = self._xypix(specmax - 1, wxmax)[0].stop

        return (xmin, xmax, ymin, ymax)

    #-------------------------------------------------------------------------
    #- Shift PSF to a new x,y grid, e.g. to account for flexure
    def shift_xy(self, dx, dy):
        """
        Shift the x,y trace locations of this PSF while preserving
        wavelength grid: xnew = x + dx, ynew = y + dy

        Not implemented in the base class.
        """
        raise NotImplementedError

    #-------------------------------------------------------------------------
    #- accessors for x, y, wavelength
    def x(self, ispec=None, wavelength=None):
        """
        Return CCD X centroid of spectrum ispec at given wavelength(s).

        ispec can be None, scalar, or vector
        wavelength can be None, scalar or a vector

        ispec   wavelength  returns
        +-------+-----------+------
        None    None        array[nspec, npix_y]
        None    scalar      vector[nspec]
        None    vector      array[nspec, nwave]
        scalar  None        array[npix_y]
        scalar  scalar      scalar
        scalar  vector      vector[nwave]
        vector  None        array[nspec, npix_y]
        vector  scalar      vector[nspec]
        vector  vector      array[nspec, nwave]
        """
        #- ispec=None -> ispec=every spectrum (same convention as y())
        if ispec is None:
            ispec = np.arange(self.nspec, dtype=int)

        if wavelength is None:
            if isinstance(ispec, (np.ndarray, list, tuple)):
                #- ispec is an array; sample each spectrum at every row
                return np.array([self._x.eval(i, self.wavelength(i))
                                 for i in ispec])
            #- scalar ispec, make wavelength an array of every row
            wavelength = self.wavelength(ispec)

        return self._x.eval(ispec, wavelength)

    def y(self, ispec=None, wavelength=None):
        """
        Return CCD Y centroid of spectrum ispec at given wavelength(s).

        ispec can be None, scalar, or vector
        wavelength can be scalar or a vector (but not None)

        ispec   wavelength  returns
        +-------+-----------+------
        None    scalar      vector[nspec]
        None    vector      array[nspec,nwave]
        scalar  scalar      scalar
        scalar  vector      vector[nwave]
        vector  scalar      vector[nspec]
        vector  vector      array[nspec, nwave]

        Raises ValueError if wavelength is None.
        """
        if wavelength is None:
            raise ValueError("PSF.y requires wavelength scalar or vector")

        if ispec is None:
            ispec = np.arange(self.nspec)

        return self._y.eval(ispec, wavelength)

    def xy(self, ispec=None, wavelength=None):
        """
        Utility function to return self.x(...) and self.y(...) in one call

        wavelength must not be None, since PSF.y requires it.
        """
        x = self.x(ispec, wavelength)
        y = self.y(ispec, wavelength)
        return x, y

    def wavelength(self, ispec=None, y=None):
        """
        Return wavelength of spectrum[ispec] evaluated at y.

        ispec can be None (all spectra), scalar, or vector
        y can be None (every CCD row 0..npix_y-1), scalar, or vector
        """
        if y is None:
            y = np.arange(0, self.npix_y)

        if ispec is None:
            ispec = np.arange(self.nspec, dtype=int)

        return self._w.eval(ispec, y)

    def angstroms_per_pixel(self, ispec, wavelength):
        """
        Return CCD pixel width in Angstroms for spectrum ispec at given
        wavelength(s).  Wavelength may be scalar or array.

        The per-row wavelength gradient is linearly interpolated to the
        requested wavelength(s); np.interp assumes the row wavelengths
        are increasing.
        """
        ww = self.wavelength(ispec, y=np.arange(self.npix_y))
        dw = np.gradient(ww)
        return np.interp(wavelength, ww, dw)

    #-------------------------------------------------------------------------
    #- Project spectra onto CCD pixels
    def project(self, wavelength, phot, specmin=0, xyrange=None, verbose=False):
        """
        Returns 2D image or 3D images of spectra projected onto the CCD

        Required inputs:
            phot[nwave] or phot[nspec, nwave] or phot[nimage, nspec, nwave]
                as photons on CCD per bin
            wavelength[nwave] or wavelength[nspec, nwave] in Angstroms
                if wavelength is 1D and spectra is 2D or 3D, then wavelength[]
                applies to all phot[i]

        Optional inputs:
            specmin : starting spectrum number
            xyrange : (xmin, xmax, ymin, ymax) range of CCD pixels
            verbose : print each spectrum index as it is projected

        if phot is 1D or 2D, output is a single 2D[ny,nx] image
        if phot is 3D[nimage,nspec,nwave], output is 3D[nimage,ny,nx]

        Raises ValueError if specmin >= self.nspec or if the last axes of
        phot and wavelength differ in length.
        """
        wavelength = np.asarray(wavelength)
        phot = np.asarray(phot)

        if specmin >= self.nspec:
            raise ValueError('specmin {} >= psf.nspec {}'.format(specmin, self.nspec))

        if phot.shape[-1] != wavelength.shape[-1]:
            raise ValueError('phot.shape {} vs. wavelength.shape {} mismatch'.format(phot.shape, wavelength.shape))

        #- x,y ranges and number of pixels
        if xyrange is None:
            xmin, xmax = (0, self.npix_x)
            ymin, ymax = (0, self.npix_y)
            xyrange = (xmin, xmax, ymin, ymax)
        else:
            xmin, xmax, ymin, ymax = xyrange

        nx = xmax - xmin
        ny = ymax - ymin

        #- convert phot to 3D[nimage, nspec, nwave]
        phot = np.atleast_2d(phot)
        if phot.ndim == 3:
            nimage, nspec, nw = phot.shape
            singleimage = False
        else:
            nspec, nw = phot.shape
            nimage = 1
            phot = phot.reshape(nimage, nspec, nw)
            singleimage = True

        if specmin + nspec > self.nspec:
            print("WARNING: specmin+nspec ({}+{}) > psf.nspec {}".format(specmin, nspec, self.nspec), file=sys.stderr)

        #- Create image to fill
        img = np.zeros((nimage, ny, nx))

        #- Loop over spectra and wavelengths
        specmax = min(specmin + nspec, self.nspec)
        for i, ispec in enumerate(range(specmin, specmax)):
            if verbose:
                print(ispec)

            #- 1D wavelength for every spec, or 2D wavelength for 2D phot?
            if wavelength.ndim == 2:
                wspec = wavelength[i]
            else:
                wspec = wavelength

            #- Evaluate non-zero photons within the wavelength range spanned
            #- by rows 0 and npix_y of this spectrum
            wmin, wmax = self.wavelength(ispec, y=(0, self.npix_y))
            for j, w in enumerate(wspec):
                if np.any(phot[:, i, j] != 0.0) and (wmin <= w <= wmax):
                    xx, yy, pix = self.xypix(ispec, w,
                        xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
                    if (xx.stop > xx.start) and (yy.stop > yy.start):
                        for k in range(nimage):
                            img[k, yy, xx] += pix * phot[k, i, j]

        if singleimage:
            return img[0]
        else:
            return img

    #- Convenience functions
    @property
    def wmin(self):
        """Minimum wavelength seen by any spectrum"""
        return self._wmin

    @property
    def wmax(self):
        """Maximum wavelength seen by any spectrum"""
        return self._wmax

    @property
    def wmin_all(self):
        """Minimum wavelength seen by all spectra"""
        return self._wmin_all

    @property
    def wmax_all(self):
        """Maximum wavelength seen by all spectra"""
        return self._wmax_all

    def projection_matrix(self, spec_range, wavelengths, xyrange, use_cache=None):
        """
        Returns sparse projection matrix from flux to pixels

        Inputs:
            spec_range = (ispecmin, ispecmax) or scalar ispec
            wavelengths = array_like wavelengths
            xyrange  = (xmin, xmax, ymin, ymax)

        Optional inputs:
            use_cache = if true, call self.cache_params(spec_range, wavelengths)
                        first so that subclasses can precompute legval values.
                        Default None (disabled).

        Returns scipy.sparse CSR matrix of shape (ny*nx, nspec*nflux); column
        (ispec-specmin)*nflux + iflux holds the spot of (ispec, wavelengths[iflux])
        flattened over the [ny, nx] subimage.  Explicit zeros are not stored.

        Usage:
            xyrange = xmin, xmax, ymin, ymax
            A = psf.projection_matrix(spec_range, wavelengths, xyrange)
            nx = xmax-xmin
            ny = ymax-ymin
            img = A.dot(phot.ravel()).reshape((ny,nx))
        """
        #- Matrix dimensions
        if isinstance(spec_range, numbers.Integral):
            specmin, specmax = spec_range, spec_range + 1
        else:
            specmin, specmax = spec_range

        xmin, xmax, ymin, ymax = xyrange
        nspec = specmax - specmin
        nflux = len(wavelengths)
        nx = xmax - xmin
        ny = ymax - ymin

        if use_cache:
            self.cache_params(spec_range, wavelengths)
        else:
            #- make sure legval_dict is empty if we're not using it
            self.legval_dict = None

        #- Build the sparse matrix directly from (pixel, flux) triplets
        #- instead of a dense [nspec*nflux, ny*nx] array
        pixel_index = np.arange(ny * nx).reshape(ny, nx)
        rows, cols, vals = list(), list(), list()
        try:
            for ispec_cache, ispec in enumerate(range(specmin, specmax)):
                for iflux, w in enumerate(wavelengths):
                    #- Get subimage and index slices; ispec_cache/iflux restart
                    #- at 0 for each patch, as needed by cached subclasses
                    xslice, yslice, pix = self.xypix(ispec, w,
                        xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
                        ispec_cache=ispec_cache, iwave_cache=iflux)

                    #- If there is overlap with pix_range, add to A
                    if pix.shape[0] > 0 and pix.shape[1] > 0:
                        ij = ispec_cache * nflux + iflux
                        rows.append(pixel_index[yslice, xslice].ravel())
                        cols.append(np.full(pix.size, ij, dtype=int))
                        vals.append(np.ravel(pix))
        finally:
            #- always clear legval_dict, even on error, so we don't enter
            #- the cached branch of _xypix at the wrong time
            self.legval_dict = None

        shape = (ny * nx, nspec * nflux)
        if not vals:
            return scipy.sparse.csr_matrix(shape)

        data = np.asarray(np.concatenate(vals), dtype=float)
        A = scipy.sparse.csr_matrix(
            (data, (np.concatenate(rows), np.concatenate(cols))), shape=shape)
        A.eliminate_zeros()
        return A

    def cache_params(self, spec_range, wavelengths):
        """
        this is implemented in specter.psf.gausshermite, everywhere else
        just an empty function
        """
        pass

    def _value(self, x, y, ispec, wavelength):
        """
        this is implemented in specter.psf.gausshermite and
        specter.psf.spotgrid, everywhere else just an empty function
        (returns None)
        """
        pass