import os
import warnings
from datetime import datetime
import astropy
import numpy as np
from astropy import units as u
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.modeling import Parameter
from astropy.modeling.models import Gaussian1D
from astropy.time import Time
from glue.core.message import SubsetUpdateMessage
from ipywidgets import widget_serialization
from packaging.version import Version
from photutils.aperture import (ApertureStats, CircularAperture, EllipticalAperture,
RectangularAperture)
from traitlets import Any, Bool, Integer, List, Unicode, observe
from jdaviz.core.events import SnackbarMessage, LinkUpdatedMessage
from jdaviz.core.region_translators import regions2aperture, _get_region_from_spatial_subset
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin, DatasetSelectMixin,
SubsetSelect, TableMixin, PlotMixin)
from jdaviz.utils import PRIHDR_KEY
__all__ = ['SimpleAperturePhotometry']
@tray_registry('imviz-aper-phot-simple', label="Imviz Simple Aperture Photometry")
class SimpleAperturePhotometry(PluginTemplateMixin, DatasetSelectMixin, TableMixin, PlotMixin):
    """Plugin that runs simple aperture photometry on a spatial subset of an image.

    The selected subset is translated to a ``photutils`` aperture, statistics are
    computed with `~photutils.aperture.ApertureStats`, the resulting row is appended
    to the plugin table, and a curve of growth or a radial profile is plotted
    (the latter optionally with a `~astropy.modeling.models.Gaussian1D` fit).
    """
    template_file = __file__, "aper_phot_simple.vue"

    # Aperture subset choices and the current selection (driven by ``self.subset``).
    subset_items = List([]).tag(sync=True)
    subset_selected = Unicode("").tag(sync=True)
    # Area of the selected aperture region, rounded up to an integer.
    subset_area = Integer().tag(sync=True)

    # Background subset choices and selection; "Manual" means the value is user-provided.
    bg_subset_items = List().tag(sync=True)
    bg_subset_selected = Unicode("").tag(sync=True)
    background_value = Any(0).tag(sync=True)

    # Optional conversion factors; a value of 0 disables the respective conversion.
    pixel_area = Any(0).tag(sync=True)  # Applied as arcsec^2 per pixel.
    counts_factor = Any(0).tag(sync=True)  # Raw sum is divided by this to get counts.
    flux_scaling = Any(0).tag(sync=True)  # Sum is divided by this before -2.5 * log10.

    # Result state shown in the GUI.
    result_available = Bool(False).tag(sync=True)
    result_failed_msg = Unicode("").tag(sync=True)
    results = List().tag(sync=True)

    # Plot state shown in the GUI.
    plot_types = List([]).tag(sync=True)
    current_plot_type = Unicode().tag(sync=True)
    plot_available = Bool(False).tag(sync=True)
    radial_plot = Any('').tag(sync=True, **widget_serialization)
    fit_radial_profile = Bool(False).tag(sync=True)
    fit_results = List().tag(sync=True)

    def __init__(self, *args, **kwargs):
        """Set up subset selectors, table headers, plot marks, and hub subscriptions."""
        super().__init__(*args, **kwargs)

        self.subset = SubsetSelect(self,
                                   'subset_items',
                                   'subset_selected',
                                   default_text=None,
                                   filters=['is_spatial', 'is_not_composite', 'is_not_annulus'])

        self.bg_subset = SubsetSelect(self,
                                      'bg_subset_items',
                                      'bg_subset_selected',
                                      default_text='Manual',
                                      manual_options=['Manual'],
                                      filters=['is_spatial', 'is_not_composite'])

        headers = ['xcenter', 'ycenter', 'sky_center',
                   'sum', 'sum_aper_area',
                   'aperture_sum_counts', 'aperture_sum_counts_err',
                   'aperture_sum_mag',
                   'min', 'max', 'mean', 'median', 'mode', 'std', 'mad_std', 'var',
                   'biweight_location', 'biweight_midvariance', 'fwhm',
                   'semimajor_sigma', 'semiminor_sigma', 'orientation', 'eccentricity',
                   'data_label', 'subset_label']
        self.table.headers_avail = headers
        self.table.headers_visible = headers

        # Cached glue data and region for the current selections.
        self._selected_data = None
        self._selected_subset = None

        self.plot_types = ["Curve of Growth", "Radial Profile", "Radial Profile (Raw)"]
        self.current_plot_type = self.plot_types[0]
        # Key under which the radial profile fit is stored in ``app.fitted_models``.
        self._fitted_model_name = 'phot_radial_profile'

        self.plot.add_line('line', color='gray', marker_size=32)
        self.plot.add_scatter('scatter', color='gray', default_size=1)
        self.plot.add_line('fit_line', color='magenta', line_style='dashed')

        self.session.hub.subscribe(self, SubsetUpdateMessage, handler=self._on_subset_update)
        self.session.hub.subscribe(self, LinkUpdatedMessage, handler=self._on_link_update)

    @observe('dataset_selected')
    def _dataset_selected_changed(self, event={}):
        """Cache the newly selected data and pre-fill conversion factors from its metadata.

        Any failure resets the cached data and is reported through a snackbar message.
        Unless the selection resolves to no data, the aperture subset is re-evaluated
        afterwards.
        """
        try:
            self._selected_data = self.dataset.selected_dc_item
            if self._selected_data is None:
                return
            self.counts_factor = 0
            self.pixel_area = 0
            self.flux_scaling = 0

            # Extract telescope specific unit conversion factors, if applicable.
            meta = self._selected_data.meta.copy()
            if PRIHDR_KEY in meta:
                # Flatten the primary header into the top-level metadata.
                meta.update(meta[PRIHDR_KEY])
                del meta[PRIHDR_KEY]
            if 'telescope' in meta:
                telescope = meta['telescope']
            else:
                telescope = meta.get('TELESCOP', '')
            if telescope == 'JWST':
                if 'photometry' in meta and 'pixelarea_arcsecsq' in meta['photometry']:
                    self.pixel_area = meta['photometry']['pixelarea_arcsecsq']
                if 'bunit_data' in meta and meta['bunit_data'] == u.Unit("MJy/sr"):
                    # Hardcode the flux conversion factor from MJy to ABmag
                    self.flux_scaling = 0.003631
            elif telescope == 'HST':
                # TODO: Add more HST support, as needed.
                # HST pixel scales are from instrument handbooks.
                # This is really not used because HST data does not have sr in unit.
                # This is only for completeness.
                # For counts conversion, PHOTFLAM is used to convert "counts" to flux manually,
                # which is the opposite of JWST, so we just do not do it here.
                instrument = meta.get('INSTRUME', '').lower()
                detector = meta.get('DETECTOR', '').lower()
                if instrument == 'acs':
                    if detector == 'wfc':
                        self.pixel_area = 0.05 * 0.05
                    elif detector == 'hrc':  # pragma: no cover
                        self.pixel_area = 0.028 * 0.025
                    elif detector == 'sbc':  # pragma: no cover
                        self.pixel_area = 0.034 * 0.03
                elif instrument == 'wfc3' and detector == 'uvis':  # pragma: no cover
                    self.pixel_area = 0.04 * 0.04

        except Exception as e:
            self._selected_data = None
            self.hub.broadcast(SnackbarMessage(
                f"Failed to extract {self.dataset_selected}: {repr(e)}",
                color='error', sender=self))

        # Update self._selected_subset with the new self._selected_data
        # and auto-populate background, if applicable.
        self._subset_selected_changed()

    def _on_subset_update(self, msg):
        """Refresh the aperture or background when the matching subset is updated."""
        if self.dataset_selected == '' or self.subset_selected == '':
            return

        sbst = msg.subset
        if sbst.label == self.subset_selected and sbst.data.label == self.dataset_selected:
            self._subset_selected_changed()
        elif sbst.label == self.bg_subset_selected and sbst.data.label == self.dataset_selected:
            self._bg_subset_selected_changed()

    def _on_link_update(self, msg):
        """Re-evaluate the aperture (and thus the background) after a link change."""
        if self.dataset_selected == '' or self.subset_selected == '':
            return

        # Force background auto-calculation to update when linking has changed.
        self._subset_selected_changed()

    @observe('subset_selected')
    def _subset_selected_changed(self, event={}):
        """Cache the region of the selected aperture subset and refresh the background.

        Does nothing when no data or no subset is selected. Extraction failures reset
        the cached region and are reported through a snackbar message.
        """
        subset_selected = event.get('new', self.subset_selected)
        if self._selected_data is None or subset_selected == '':
            return

        try:
            self._selected_subset = _get_region_from_spatial_subset(self, subset_selected)
            self._selected_subset.meta['label'] = subset_selected
            self.subset_area = int(np.ceil(self._selected_subset.area))
        except Exception as e:
            self._selected_subset = None
            self.hub.broadcast(SnackbarMessage(
                f"Failed to extract {subset_selected}: {repr(e)}", color='error', sender=self))
        else:
            # Only recompute the background when the aperture itself is valid.
            self._bg_subset_selected_changed()

    def _calc_bg_subset_median(self, reg):
        """Return the NaN-ignoring median of the selected data inside region ``reg``."""
        # Basically same way image stats are calculated in vue_do_aper_phot()
        # except here we only care about one stat for the background.
        data = self._selected_data
        comp = data.get_component(data.main_components[0])
        aper_mask_stat = reg.to_mask(mode='center')
        img_stat = aper_mask_stat.get_values(comp.data, mask=None)

        # photutils/background/_utils.py --> nanmedian()
        return np.nanmedian(img_stat)  # Naturally in data unit

    @observe('bg_subset_selected')
    def _bg_subset_selected_changed(self, event={}):
        """Auto-populate ``background_value`` from the selected background subset.

        With "Manual" selected the user-provided value is left alone. On failure the
        background is reset to 0 and the error is reported through a snackbar message.
        """
        bg_subset_selected = event.get('new', self.bg_subset_selected)
        if bg_subset_selected == 'Manual':
            # we'll later access the user's self.background_value directly
            return

        try:
            reg = _get_region_from_spatial_subset(self, bg_subset_selected)
            self.background_value = self._calc_bg_subset_median(reg)
        except Exception as e:
            self.background_value = 0
            self.hub.broadcast(SnackbarMessage(
                f"Failed to extract {bg_subset_selected}: {repr(e)}", color='error', sender=self))

    def vue_do_aper_phot(self, *args, **kwargs):
        """Run aperture photometry for the current selections and publish the results.

        On success a row is appended to the plugin table, the plot is updated
        according to ``current_plot_type`` and the GUI result lists are refreshed.
        On failure all plot marks are cleared and the error is stored in
        ``result_failed_msg`` and broadcast as a snackbar message; the GUI result
        lists are left untouched in that case.

        ``*args`` and ``**kwargs`` are accepted for the GUI callback but not used.
        """
        if self._selected_data is None or self._selected_subset is None:
            self.hub.broadcast(SnackbarMessage(
                "No data for aperture photometry", color='error', sender=self))
            return

        data = self._selected_data
        reg = self._selected_subset
        xcenter = reg.center.x
        ycenter = reg.center.y

        # Reset last fitted model
        fit_model = None
        if self._fitted_model_name in self.app.fitted_models:
            del self.app.fitted_models[self._fitted_model_name]

        try:
            comp = data.get_component(data.main_components[0])
            # TypeError covers non-string, non-numeric trait values such as None.
            try:
                bg = float(self.background_value)
            except (TypeError, ValueError):  # Clearer error message
                raise ValueError('Missing or invalid background value') from None
            if data.coords is not None:
                sky_center = data.coords.pixel_to_world(xcenter, ycenter)
            else:
                sky_center = None
            aperture = regions2aperture(reg)
            include_pixarea_fac = False
            include_counts_fac = False
            include_flux_scale = False
            comp_data = comp.data
            # Conversions below need a data unit, so they are skipped for unitless data.
            if comp.units:
                img_unit = u.Unit(comp.units)
                bg = bg * img_unit
                comp_data = comp_data << img_unit

                if u.sr in img_unit.bases:  # TODO: Better way to detect surface brightness unit?
                    try:
                        pixarea = float(self.pixel_area)
                    except (TypeError, ValueError):  # Clearer error message
                        raise ValueError('Missing or invalid pixel area') from None
                    if not np.allclose(pixarea, 0):
                        include_pixarea_fac = True
                if img_unit != u.count:
                    try:
                        ctfac = float(self.counts_factor)
                    except (TypeError, ValueError):  # Clearer error message
                        raise ValueError('Missing or invalid counts conversion factor') from None
                    if not np.allclose(ctfac, 0):
                        include_counts_fac = True
                try:
                    flux_scale = float(self.flux_scaling)
                except (TypeError, ValueError):  # Clearer error message
                    raise ValueError('Missing or invalid flux scaling') from None
                if not np.allclose(flux_scale, 0):
                    include_flux_scale = True
            phot_aperstats = ApertureStats(comp_data, aperture, wcs=data.coords, local_bkg=bg)
            phot_table = phot_aperstats.to_table(columns=(
                'id', 'sum', 'sum_aper_area',
                'min', 'max', 'mean', 'median', 'mode', 'std', 'mad_std', 'var',
                'biweight_location', 'biweight_midvariance', 'fwhm', 'semimajor_sigma',
                'semiminor_sigma', 'orientation', 'eccentricity'))  # Some cols excluded, add back as needed.  # noqa
            rawsum = phot_table['sum'][0]

            if include_pixarea_fac:
                pixarea = pixarea * (u.arcsec * u.arcsec / (u.pix * u.pix))
                # NOTE: Sum already has npix value encoded, so we simply apply the npix unit here.
                pixarea_fac = (u.pix * u.pix) * pixarea.to(u.sr / (u.pix * u.pix))
                phot_table['sum'] = [rawsum * pixarea_fac]
            else:
                pixarea_fac = None

            if include_counts_fac:
                ctfac = ctfac * (rawsum.unit / u.count)
                sum_ct = rawsum / ctfac
                sum_ct_err = np.sqrt(sum_ct.value) * sum_ct.unit
            else:
                ctfac = None
                sum_ct = None
                sum_ct_err = None

            if include_flux_scale:
                flux_scale = flux_scale * phot_table['sum'][0].unit
                sum_mag = -2.5 * np.log10(phot_table['sum'][0] / flux_scale) * u.mag
            else:
                flux_scale = None
                sum_mag = None

            # Extra info beyond photutils.
            # Time.now() replaces Time(datetime.utcnow()), which is deprecated in Python 3.12+.
            phot_table.add_columns(
                [xcenter * u.pix, ycenter * u.pix, sky_center,
                 bg, pixarea_fac, sum_ct, sum_ct_err, ctfac, sum_mag, flux_scale, data.label,
                 reg.meta.get('label', ''), Time.now()],
                names=['xcenter', 'ycenter', 'sky_center', 'background', 'pixarea_tot',
                       'aperture_sum_counts', 'aperture_sum_counts_err', 'counts_fac',
                       'aperture_sum_mag', 'flux_scaling',
                       'data_label', 'subset_label', 'timestamp'],
                indexes=[1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 18, 18, 18])

            try:
                phot_table['id'][0] = self.table._qtable['id'].max() + 1
                self.table.add_item(phot_table)
            except Exception:  # Discard incompatible QTable
                self.table.clear_table()
                phot_table['id'][0] = 1
                self.table.add_item(phot_table)

            # Plots.
            line = self.plot.marks['line']
            sc = self.plot.marks['scatter']
            fit_line = self.plot.marks['fit_line']

            if self.current_plot_type == "Curve of Growth":
                self.plot.figure.title = 'Curve of growth from aperture center'
                x_arr, sum_arr, x_label, y_label = _curve_of_growth(
                    comp_data, (xcenter, ycenter), aperture, phot_table['sum'][0],
                    wcs=data.coords, background=bg, pixarea_fac=pixarea_fac)
                line.x, line.y = x_arr, sum_arr
                self.plot.clear_marks('scatter', 'fit_line')
                self.plot.figure.axes[0].label = x_label
                self.plot.figure.axes[1].label = y_label

            else:  # Radial profile
                self.plot.figure.axes[0].label = 'pix'
                self.plot.figure.axes[1].label = comp.units or 'Value'

                if self.current_plot_type == "Radial Profile":
                    self.plot.figure.title = 'Radial profile from aperture center'
                    x_data, y_data = _radial_profile(
                        phot_aperstats.data_cutout, phot_aperstats.bbox, (xcenter, ycenter),
                        raw=False)
                    line.x, line.y = x_data, y_data
                    self.plot.clear_marks('scatter')
                else:  # Radial Profile (Raw)
                    self.plot.figure.title = 'Raw radial profile from aperture center'
                    x_data, y_data = _radial_profile(
                        phot_aperstats.data_cutout, phot_aperstats.bbox, (xcenter, ycenter),
                        raw=True)
                    sc.x, sc.y = x_data, y_data
                    self.plot.clear_marks('line')

                # Fit Gaussian1D to radial profile data.
                if self.fit_radial_profile:
                    fitter = LevMarLSQFitter()
                    y_max = y_data.max()
                    x_mean = x_data[np.where(y_data == y_max)].mean()
                    std = 0.5 * (phot_table['semimajor_sigma'][0] +
                                 phot_table['semiminor_sigma'][0])
                    if isinstance(std, u.Quantity):
                        std = std.value
                    gs = Gaussian1D(amplitude=y_max, mean=x_mean, stddev=std,
                                    fixed={'amplitude': True},
                                    bounds={'amplitude': (y_max * 0.5, y_max)})
                    # filter_non_finite is only passed for astropy 5.2 and later.
                    if Version(astropy.__version__) < Version('5.2'):
                        fitter_kw = {}
                    else:
                        fitter_kw = {'filter_non_finite': True}
                    with warnings.catch_warnings(record=True) as warns:
                        fit_model = fitter(gs, x_data, y_data, **fitter_kw)
                    if len(warns) > 0:
                        msg = os.linesep.join([str(w.message) for w in warns])
                        self.hub.broadcast(SnackbarMessage(
                            f"Radial profile fitting: {msg}", color='warning', sender=self))
                    y_fit = fit_model(x_data)
                    self.app.fitted_models[self._fitted_model_name] = fit_model
                    fit_line.x, fit_line.y = x_data, y_fit
                else:
                    self.plot.clear_marks('fit_line')

        except Exception as e:  # pragma: no cover
            self.plot.clear_all_marks()
            msg = f"Aperture photometry failed: {repr(e)}"
            self.hub.broadcast(SnackbarMessage(msg, color='error', sender=self))
            self.result_failed_msg = msg
        else:
            self.result_failed_msg = ''

            # Parse results for GUI.
            # This stays on the success path because phot_table may not exist otherwise.
            tmp = []
            for key in phot_table.colnames:
                if key in ('id', 'data_label', 'subset_label', 'background', 'pixarea_tot',
                           'counts_fac', 'aperture_sum_counts_err', 'flux_scaling', 'timestamp'):
                    continue
                x = phot_table[key][0]
                if (isinstance(x, (int, float, u.Quantity)) and
                        key not in ('xcenter', 'ycenter', 'sky_center', 'sum_aper_area',
                                    'aperture_sum_counts', 'aperture_sum_mag')):
                    tmp.append({'function': key, 'result': f'{x:.4e}'})
                elif key == 'sky_center' and x is not None:
                    tmp.append({'function': 'RA center', 'result': f'{x.ra.deg:.6f} deg'})
                    tmp.append({'function': 'Dec center', 'result': f'{x.dec.deg:.6f} deg'})
                elif key in ('xcenter', 'ycenter', 'sum_aper_area'):
                    tmp.append({'function': key, 'result': f'{x:.1f}'})
                elif key == 'aperture_sum_counts' and x is not None:
                    tmp.append({'function': key, 'result':
                                f'{x:.4e} ({phot_table["aperture_sum_counts_err"][0]:.4e})'})
                elif key == 'aperture_sum_mag' and x is not None:
                    tmp.append({'function': key, 'result': f'{x:.3f}'})
                else:
                    tmp.append({'function': key, 'result': str(x)})

            # Also display fit results
            fit_tmp = []
            if fit_model is not None and isinstance(fit_model, Gaussian1D):
                for param in ('mean', 'fwhm', 'amplitude'):
                    p_val = getattr(fit_model, param)
                    if isinstance(p_val, Parameter):
                        p_val = p_val.value
                    fit_tmp.append({'function': param, 'result': f'{p_val:.4e}'})

            self.results = tmp
            self.fit_results = fit_tmp
            self.result_available = True
            self.plot_available = True
# NOTE: These are hidden because the APIs are for internal use only,
# but we need them as separate functions for unit testing.
def _radial_profile(radial_cutout, reg_bb, centroid, raw=False):
    """Calculate radial profile.

    Parameters
    ----------
    radial_cutout : `~numpy.ma.MaskedArray`
        Cutout image from ``ApertureStats``. Masked pixels are excluded.

    reg_bb : obj
        Bounding box from ``ApertureStats``. Must provide ``ixmin``, ``ixmax``,
        ``iymin``, ``iymax``, and ``shape``.

    centroid : tuple of number
        ``ApertureStats`` centroid or desired center in ``(x, y)``.

    raw : bool
        If `True`, returns raw data points for scatter plot.
        Otherwise, use ``imexam`` algorithm for a clean plot.

    Returns
    -------
    x_arr : ndarray
        Distance from ``centroid`` in pixels. If ``raw=True``, this is the sorted
        distance of every unmasked pixel; otherwise, it is the integer radius of
        each bin, starting from zero.

    y_arr : ndarray
        Pixel values matching ``x_arr``. If ``raw=False``, this is the mean value
        in each radius bin, which is NaN for a bin without any unmasked pixel.

    """
    reg_ogrid = np.ogrid[reg_bb.iymin:reg_bb.iymax, reg_bb.ixmin:reg_bb.ixmax]
    radial_dx = reg_ogrid[1] - centroid[0]
    radial_dy = reg_ogrid[0] - centroid[1]
    radial_r = np.hypot(radial_dx, radial_dy)

    # Sometimes the mask is smaller than radial_r
    if radial_cutout.shape != reg_bb.shape:
        radial_r = radial_r[:radial_cutout.shape[0], :radial_cutout.shape[1]]

    radial_r = radial_r[~radial_cutout.mask].ravel()  # pix
    radial_img = radial_cutout.compressed()  # data unit

    if raw:
        i_arr = np.argsort(radial_r)
        x_arr = radial_r[i_arr]
        y_arr = radial_img[i_arr]
    else:
        # This algorithm is from the imexam package,
        # see licenses/IMEXAM_LICENSE.txt for more details
        radial_r = np.rint(radial_r).astype(int)
        n_pix = np.bincount(radial_r)
        # A radius bin without unmasked pixels is 0/0; keep the resulting NaN so that
        # x_arr stays aligned with the bin index, but do not emit a RuntimeWarning.
        with np.errstate(divide='ignore', invalid='ignore'):
            y_arr = np.bincount(radial_r, radial_img) / n_pix
        x_arr = np.arange(y_arr.size)

    return x_arr, y_arr
def _curve_of_growth(data, centroid, aperture, final_sum, wcs=None, background=0, n_datapoints=10,
                     pixarea_fac=None):
    """Calculate curve of growth for aperture photometry.

    Aperture sums are computed for ``n_datapoints - 1`` apertures that grow
    linearly toward the size of the given aperture; the last point reuses
    ``final_sum`` instead of recomputing it.

    Parameters
    ----------
    data : ndarray or `~astropy.units.Quantity`
        Data for the calculation.

    centroid : tuple of number
        ``ApertureStats`` centroid or desired center in ``(x, y)``.

    aperture : obj
        ``photutils`` aperture to use, except its center will be
        changed to the given ``centroid``. This is because the aperture
        might be hand-drawn and a more accurate centroid has been
        recalculated separately.

    final_sum : float or `~astropy.units.Quantity`
        Aperture sum that is already calculated in the
        main plugin above.

    wcs : obj or `None`
        Supported WCS objects or `None`.

    background : float or `~astropy.units.Quantity`
        Background to subtract, if any. Unit must match ``data``.

    n_datapoints : int
        Number of data points in the curve. Must be at least 1.

    pixarea_fac : float or `None`
        For ``flux_unit/sr`` to ``flux_unit`` conversion.

    Returns
    -------
    x_arr : ndarray
        Data for X-axis of the curve.

    sum_arr : ndarray
        Data for Y-axis of the curve, with any unit stripped.

    x_label, y_label : str
        X- and Y-axis labels, respectively.

    Raises
    ------
    TypeError
        Unsupported aperture.

    ValueError
        ``n_datapoints`` is less than 1.

    """
    if n_datapoints < 1:
        raise ValueError(f'n_datapoints must be at least 1, got {n_datapoints}')

    n_datapoints += 1  # n + 1

    # In each case, x_arr excludes the zero-size starting point, and the last
    # (full-size) aperture is skipped because final_sum already covers it.
    if isinstance(aperture, CircularAperture):
        x_label = 'Radius (pix)'
        x_arr = np.linspace(0, aperture.r, num=n_datapoints)[1:]
        aper_list = [CircularAperture(centroid, cur_r) for cur_r in x_arr[:-1]]
    elif isinstance(aperture, EllipticalAperture):
        x_label = 'Semimajor axis (pix)'
        x_arr = np.linspace(0, aperture.a, num=n_datapoints)[1:]
        a_arr = x_arr[:-1]
        b_arr = aperture.b * a_arr / aperture.a  # Keep the axis ratio fixed.
        aper_list = [EllipticalAperture(centroid, cur_a, cur_b, theta=aperture.theta)
                     for (cur_a, cur_b) in zip(a_arr, b_arr)]
    elif isinstance(aperture, RectangularAperture):
        x_label = 'Width (pix)'
        x_arr = np.linspace(0, aperture.w, num=n_datapoints)[1:]
        w_arr = x_arr[:-1]
        h_arr = aperture.h * w_arr / aperture.w  # Keep the aspect ratio fixed.
        aper_list = [RectangularAperture(centroid, cur_w, cur_h, theta=aperture.theta)
                     for (cur_w, cur_h) in zip(w_arr, h_arr)]
    else:
        raise TypeError(f'Unsupported aperture: {aperture}')

    sum_arr = [ApertureStats(data, cur_aper, wcs=wcs, local_bkg=background).sum
               for cur_aper in aper_list]

    if sum_arr:
        if isinstance(sum_arr[0], u.Quantity):
            sum_arr = u.Quantity(sum_arr)
        else:
            sum_arr = np.array(sum_arr)
        if pixarea_fac is not None:
            sum_arr = sum_arr * pixarea_fac
        sum_arr = np.append(sum_arr, final_sum)
    elif isinstance(final_sum, u.Quantity):
        # n_datapoints=1 leaves no inner apertures; the curve is just the final sum.
        sum_arr = u.Quantity([final_sum])
    else:
        sum_arr = np.array([final_sum])

    if isinstance(sum_arr, u.Quantity):
        y_label = sum_arr.unit.to_string()
        sum_arr = sum_arr.value  # bqplot does not like Quantity
    else:
        y_label = 'Value'

    return x_arr, sum_arr, x_label, y_label