# Source code for squirrel.template

"""This module contains the class to store stellar and other templates and process
them."""

import numpy as np
from copy import deepcopy
from .data import Spectra


class Template(Spectra):
    """A class to store stellar and other templates and process them.

    The flux is kept as a 2D array whose columns are individual templates
    sharing one wavelength grid: a 1D flux is promoted to a single column on
    construction, ``merge`` stacks templates along axis 1, and the weight
    based methods treat ``flux.shape[1]`` as the number of templates.
    """

    def __init__(
        self,
        wavelengths,
        flux,
        wavelength_unit,
        fwhm,
        flux_unit="arbitrary",
    ):
        """Initialize the template.

        :param wavelengths: wavelengths of the template
        :type wavelengths: numpy.ndarray
        :param flux: flux of the template; a 1D array is treated as a single
            template and reshaped into one column
        :type flux: numpy.ndarray
        :param wavelength_unit: unit of the wavelength
        :type wavelength_unit: str
        :param fwhm: full width at half maximum
        :type fwhm: float
        :param flux_unit: unit of the flux
        :type flux_unit: str
        """
        # Store templates as columns so several of them can later be stacked
        # side by side (see ``merge``).
        if flux.ndim == 1:
            flux = flux[:, np.newaxis]

        # Templates carry no redshift or noise information: both redshifts are
        # fixed to 0 and no noise array is attached.
        super().__init__(
            wavelengths=wavelengths,
            flux=flux,
            wavelength_unit=wavelength_unit,
            fwhm=fwhm,
            z_lens=0.0,
            z_source=0.0,
            flux_unit=flux_unit,
            noise=None,
        )

    @staticmethod
    def _as_template_columns(flux):
        """Return ``flux`` as a 2D array with one template per column."""
        flux = np.atleast_2d(flux)
        # A single row is interpreted as one template laid out horizontally
        # (this is also what ``atleast_2d`` produces from a 1D array), so turn
        # it into a column.
        if flux.shape[0] == 1:
            flux = flux.T
        return flux

    def merge(self, other):
        """Merge the template with another template.

        The flux columns of ``other`` are appended after those of ``self``;
        neither input is modified.

        :param other: template to merge with
        :type other: squirrel.template.Template
        :return: A new Template instance with merged flux
        :rtype: squirrel.template.Template
        :raises AssertionError: if the wavelength units, FWHM or wavelength
            grids of the two templates differ
        """
        # Explicit raises instead of ``assert`` so the checks still run under
        # ``python -O``; AssertionError keeps the exception type consistent
        # with ``np.testing.assert_equal`` below.
        if self.wavelength_unit != other.wavelength_unit:
            raise AssertionError("Wavelength units do not match")
        if self.fwhm != other.fwhm:
            raise AssertionError("FWHM do not match")
        np.testing.assert_equal(
            self.wavelengths, other.wavelengths, err_msg="Wavelengths do not match"
        )

        new_template = deepcopy(self)
        new_template.flux = np.concatenate(
            (
                self._as_template_columns(self.flux),
                self._as_template_columns(other.flux),
            ),
            axis=1,
        )
        return new_template

    def __and__(self, other):
        """Merge the template with another template using the & operator.

        :param other: template to merge with
        :type other: squirrel.template.Template
        :return: A new Template instance with merged flux
        :rtype: squirrel.template.Template
        """
        # ``merge`` already works on a deep copy, so no extra copy is needed.
        return self.merge(other)

    def __iand__(self, other):
        """Merge the template with another template using the &= operator.

        The flux of this instance is replaced by the merged flux, so other
        references to this object observe the merge as well.

        :param other: template to merge with
        :type other: squirrel.template.Template
        :return: The current Template instance with merged flux
        :rtype: squirrel.template.Template
        """
        # Validation and concatenation happen in ``merge``; if it raises,
        # ``self`` is left untouched.
        merged = self.merge(other)
        self.flux = merged.flux
        return self

    def combine_weighted(self, weights):
        """Combine the templates into one single template using weighted sum.

        The weighted sum is normalised by its median value.

        :param weights: weights for each template
        :type weights: numpy.array
        :return: A new Template instance with combined flux
        :rtype: squirrel.template.Template
        """
        new_template = deepcopy(self)

        # Weighted sum over the template columns.
        flux = new_template.flux @ weights
        # Out-of-place division: in-place ``/=`` fails when the product has an
        # integer dtype because true division yields floats.
        flux = flux / np.median(flux)

        # Keep the one-template-per-column layout.
        if flux.ndim == 1:
            flux = flux[:, np.newaxis]

        new_template.flux = flux
        return new_template

    def discard_zero_weights(self, weights):
        """Discard the templates with zero weights.

        Only templates whose weight is strictly positive are kept; zero (and
        negative) weights are dropped.

        :param weights: weights for each template; may be longer than the
            number of templates, in which case the extra trailing entries are
            ignored
        :type weights: numpy.array
        :return: A new Template instance with non-zero weighted flux
        :rtype: squirrel.template.Template
        :raises ValueError: if fewer weights than templates are given
        """
        # Accept plain sequences as well as arrays (``list > 0`` would fail).
        weights = np.asarray(weights)
        n_templates = self.flux.shape[1]

        # If, e.g., a background_spectra component (`sky` in ppxf) was used,
        # the weights can be more numerous than the number of templates; only
        # the leading ``n_templates`` entries belong to the templates.
        if len(weights) < n_templates:
            raise ValueError(
                "The number of weights cannot be smaller than the number of templates."
            )
        template_weights = weights[:n_templates]

        # Select the flux columns corresponding to positive weights.
        non_zero_indices = np.where(template_weights > 0)[0]

        new_template = deepcopy(self)
        new_template.flux = new_template.flux[:, non_zero_indices]
        return new_template