# Source code for jdaviz.configs.cubeviz.plugins.slice.slice

import threading
import time
import warnings

import numpy as np
from astropy.units import UnitsWarning
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.bqplot.profile import BqplotProfileView
from traitlets import Bool, Float, observe, Any, Int
from specutils.spectra.spectrum1d import Spectrum1D

from jdaviz.core.events import AddDataMessage, SliceToolStateMessage, SliceSelectSliceMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import PluginTemplateMixin
from jdaviz.core.user_api import PluginUserApi

# Names exported by ``from <this module> import *``: only the plugin class.
__all__ = ['Slice']


@tray_registry('cubeviz-slice', label="Slice", viewer_requirements='spectrum')
class Slice(PluginTemplateMixin):
    """
    See the :ref:`Slice Plugin Documentation <slice>` for more details.

    Only the following attributes and methods are available through the
    :ref:`public plugin API <plugin-apis>`:

    * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.show`
    * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.open_in_tray`
    * ``slice``
      Current slice number.
    * ``wavelength``
      Wavelength of the current slice.  When setting this directly, it will
      update automatically to the wavelength corresponding to the nearest slice.
    * ``show_indicator``
      Whether to show indicator in spectral viewer when slice tool is inactive.
    * ``show_wavelength``
      Whether to show slice wavelength in label to right of indicator.
    """
    template_file = __file__, "slice.vue"

    # Traitlets synced with the front-end template.
    slice = Any(0).tag(sync=True)
    wavelength = Any(-1).tag(sync=True)  # -1 is the "not yet initialized" sentinel
    wavelength_unit = Any("").tag(sync=True)
    min_value = Float(0).tag(sync=True)
    max_value = Float(100).tag(sync=True)
    wait = Int(200).tag(sync=True)  # NOTE(review): consumed by slice.vue only - confirm meaning

    show_indicator = Bool(True).tag(sync=True)
    show_wavelength = Bool(True).tag(sync=True)

    is_playing = Bool(False).tag(sync=True)
    play_interval = Int(200).tag(sync=True)  # milliseconds

    def __init__(self, *args, **kwargs):
        """Set up internal state, watch existing viewers and subscribe to hub messages."""
        self._default_spectrum_viewer_reference_name = kwargs.get(
            "spectrum_viewer_reference_name", "spectrum-viewer"
        )
        self._default_image_viewer_reference_name = kwargs.get(
            "image_viewer_reference_name", "image-viewer"
        )
        super().__init__(*args, **kwargs)

        self._watched_viewers = []    # image viewers whose ``slices`` state we follow
        self._indicator_viewers = []  # profile viewers that draw the slice indicator
        self._x_all = None            # cached spectral-axis values (unit stripped)
        self._player = None           # background thread used while "playing"

        # initialize watching existing viewers WITH data (if initializing the plugin after data
        # already exists - otherwise the AddDataMessage will handle watching image viewers once
        # data is available)
        for viewer in self.app._viewer_store.values():
            if isinstance(viewer, BqplotProfileView) or len(viewer.data()):
                self._watch_viewer(viewer, True)

        # Subscribe to requests from the helper to change the slice across all viewers
        self.session.hub.subscribe(self, SliceSelectSliceMessage,
                                   handler=self._on_select_slice_message)

        # Listen for add data events. **Note** this should only be used in
        # cases where there is a specific type of data expected and arbitrary
        # viewers are not expected to be created. That is, the expected data
        # in _all_ viewers should be uniform.
        self.session.hub.subscribe(self, AddDataMessage,
                                   handler=self._on_data_added)

    @property
    def user_api(self):
        """Return the restricted public API wrapper for this plugin."""
        return PluginUserApi(self, expose=('slice', 'wavelength',
                                           'show_indicator', 'show_wavelength'))

    def _watch_viewer(self, viewer, watch=True):
        """Start (or stop) tracking ``viewer``.

        Image viewers get a callback on their ``slices`` state (added when
        ``watch`` is true, removed otherwise).  Profile viewers are only ever
        added: they are registered as indicator viewers and their
        ``reference_data`` state is observed.
        """
        if isinstance(viewer, BqplotImageView):
            if watch and viewer not in self._watched_viewers:
                self._watched_viewers.append(viewer)
                viewer.state.add_callback('slices', self._viewer_slices_changed)
            elif not watch and viewer in self._watched_viewers:
                viewer.state.remove_callback('slices', self._viewer_slices_changed)
                self._watched_viewers.remove(viewer)
        elif isinstance(viewer, BqplotProfileView) and watch:
            if self._x_all is None and len(viewer.data()):
                # cache wavelengths so that wavelength <> slice conversion can be done efficiently
                self._update_data(viewer.data()[0].spectral_axis)

            if viewer not in self._indicator_viewers:
                self._indicator_viewers.append(viewer)
                # if the units (or data) change, we need to update internally
                viewer.state.add_callback("reference_data", self._update_reference_data)

    def _on_data_added(self, msg):
        """Handle ``AddDataMessage``: start watching the viewer that received data."""
        if isinstance(msg.viewer, BqplotImageView):
            if len(msg.data.shape) == 3:
                # the slider range follows the length of the last data axis
                self.max_value = msg.data.shape[-1] - 1
                self._watch_viewer(msg.viewer, True)
        elif isinstance(msg.viewer, BqplotProfileView):
            self._watch_viewer(msg.viewer, True)

    def _update_reference_data(self, reference_data):
        """Refresh the cached spectral axis when a profile viewer's reference data changes."""
        if reference_data is None:
            return  # pragma: no cover
        self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis)

    def _update_data(self, x_all):
        """Cache ``x_all`` as the spectral axis and re-sync slice and wavelength.

        If ``x_all`` carries a ``unit`` it is recorded in ``wavelength_unit``
        and only the bare values are cached.
        """
        if hasattr(x_all, 'unit'):
            self.wavelength_unit = str(x_all.unit)
            x_all = x_all.value
        self._x_all = x_all

        if self.wavelength == -1:
            if len(x_all):
                # initialize at middle of cube
                self.slice = int(len(x_all)/2)
            else:
                # leave in the pre-init state and don't update the wavelength/slice
                return

        # force wavelength to update from the current slider value
        self._on_slider_updated({'new': self.slice})

    def _viewer_slices_changed(self, value):
        """Mirror a change of a watched viewer's ``slices`` state onto ``self.slice``."""
        # the slices attribute on the viewer state was changed,
        # so we'll update the slider to match which will trigger
        # the slider observer (_on_slider_updated) and sync across
        # any other applicable viewers
        if len(value) == 3:
            self.slice = float(value[-1])

    def _on_select_slice_message(self, msg):
        """Handle ``SliceSelectSliceMessage`` by adopting the requested slice."""
        # NOTE: by setting the slice index, the observer (_on_slider_updated)
        # will sync across all viewers and update the wavelength
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UnitsWarning)
            self.slice = msg.slice

    @observe('wavelength')
    def _on_wavelength_updated(self, event):
        """Snap a newly entered wavelength to the nearest slice."""
        # convert to float (JS handles stripping any invalid characters)
        try:
            value = float(event.get('new'))
        except (TypeError, ValueError):
            # do not accept changes, we'll revert via the slider
            # since this @change event doesn't have access to
            # the old value, and self.wavelength already updated
            # via the v-model.  TypeError covers a missing/None value.
            self._on_slider_updated({'new': self.slice})
            return

        if self._x_all is None or not len(self._x_all):
            # no spectral axis cached yet: there is nothing to snap to
            return

        # NOTE: by setting the index, this should recursively update the
        # wavelength to the nearest applicable value in _on_slider_updated
        self.slice = int(np.argmin(abs(value - self._x_all)))

    @observe('show_indicator', 'show_wavelength')
    def _on_setting_changed(self, event):
        """Broadcast a changed display setting as a ``SliceToolStateMessage``."""
        msg = SliceToolStateMessage({event['name']: event['new']}, sender=self)
        self.session.hub.broadcast(msg)

    @observe('slice')
    def _on_slider_updated(self, event):
        """Propagate a slice index to the wavelength, image viewers and indicators.

        ``event`` is a dict-like object; its ``'new'`` entry is the requested
        index, which is wrapped into ``[0, max_value]`` with a modulo.
        """
        if self._x_all is None or not len(self._x_all):
            # nothing cached (or an empty axis): indexing below would fail
            return

        value = int(event.get('new', int(len(self._x_all)/2))) % (int(self.max_value) + 1)

        self.wavelength = self._x_all[value]

        for viewer in self._watched_viewers:
            viewer.state.slices = (0, 0, value)
        for viewer in self._indicator_viewers:
            viewer._update_slice_indicator(value)

    def vue_goto_first(self, *args):
        """Jump to ``min_value``; ignored while playing."""
        if self.is_playing:
            return
        self._on_slider_updated({'new': self.min_value})

    def vue_goto_last(self, *args):
        """Jump to ``max_value``; ignored while playing."""
        if self.is_playing:
            return
        self._on_slider_updated({'new': self.max_value})

    def vue_play_next(self, *args):
        """Advance by one slice (wrapping via the modulo in the slider handler)."""
        if self.is_playing:
            return
        self._on_slider_updated({'new': self.slice + 1})

    def _player_worker(self):
        """Thread target: keep advancing one slice per interval while ``is_playing``."""
        ts = float(self.play_interval) * 1e-3  # ms to s
        while self.is_playing:
            self._on_slider_updated({'new': self.slice + 1})
            time.sleep(ts)

    def vue_play_start_stop(self, *args):
        """Toggle playback: stop the running player, or start a new player thread."""
        if self.is_playing:  # Stop
            if self._player:
                if self._player.is_alive():
                    # timeout=0 makes this non-blocking; the worker loop ends on
                    # its own once it sees ``is_playing`` cleared below
                    self._player.join(timeout=0)
                self._player = None
            self.is_playing = False
            return

        if self._x_all is None:
            # no data loaded yet, so there is nothing to play through
            return

        # Start
        self.is_playing = True
        self._player = threading.Thread(target=self._player_worker)
        self._player.start()