Single Baseline pI FRF SNR¶

by Josh Dillon and Tyler Cox, last updated May 26, 2026

This notebook takes corner-turned, calibrated, redundantly-averaged visibility data, forms pseudo-Stokes pI, and computes delay-filtered, fringe-rate-filtered SNR waterfalls. The results are written out as uvh5 files to be combined across baselines to look for residual structure that fringes like the main beam.

Here's a set of links to skip to particular figures and tables:

• Figure 1: Delay-Filtered pI SNR in Fringe Rate Space¶

• Figure 2: Delay-Filtered pI SNR Waterfall¶

• Figure 3: Delay-Filtered pI SNR Histogram¶

• Figure 4: Delay+FR Filtered pI SNR Waterfall¶

• Figure 5: Delay+FR Filtered pI SNR Histogram¶

In [1]:
import time
tstart = time.time()
!hostname
!date
herapost012
Sat May 30 11:52:23 MDT 2026
In [2]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin
import numpy as np
import yaml
import copy
import re
import glob
import matplotlib
from astropy import units, constants
from scipy import interpolate
from scipy.signal.windows import blackmanharris
from pyuvdata import UVFlag
from hera_cal import io, utils, flag_utils, red_groups, polfilt
from hera_cal.frf import sky_frates_single, get_FR_buffer_from_spectra, get_m2f_mixer
from hera_filters.dspec import dpss_operator, fourier_filter
import hera_pspec as hp
import uvtools
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline
In [3]:
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459935/zen.2459935.21408.sum.smooth_calibrated.red_avg.uvh5" # TODO

CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", 
                                      os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml"))
FRF_SNR_SUFFIX = os.environ.get("FRF_SNR_SUFFIX", ".pI_FRF_SNR.uvh5")
SAVE_DLY_SNR = os.environ.get("SAVE_DLY_SNR", "TRUE").upper() == "TRUE"
DLY_SNR_SUFFIX = os.environ.get("DLY_SNR_SUFFIX", ".pI_DLYFILT_SNR.uvh5")
SAVE_FRF_SNR = os.environ.get("SAVE_FRF_SNR", "TRUE").upper() == "TRUE"

FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5))  # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0))  # in MHz

FILTER_DELAY = float(os.environ.get("FILTER_DELAY", 750))  # in ns
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))

FR_SPECTRA_FILE = os.environ.get("FR_SPECTRA_FILE", 
                                 "/lustre/aoc/projects/hera/h6c-analysis/IDR3/beam_simulation_products/spectra_cache_hera_core.h5")
AUTO_FR_SPECTRUM_FILE = os.environ.get("AUTO_FR_SPECTRUM_FILE",
                                       "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5")
XTALK_FRATE = float(os.environ.get("XTALK_FRATE", 0.01))  # in mHz
SKIP_FR0_OVERLAP_BASELINES = os.environ.get("SKIP_FR0_OVERLAP_BASELINES", "FALSE").upper() == "TRUE"

FR_QUANTILE_LOW = float(os.environ.get("FR_QUANTILE_LOW", 0.05))
FR_QUANTILE_HIGH = float(os.environ.get("FR_QUANTILE_HIGH", 0.95))

MIN_SAMP_FRAC = float(os.environ.get("MIN_SAMP_FRAC", 0.1))
RIDGE_ALPHA = float(os.environ.get("RIDGE_ALPHA", 1e-12))
LEVERAGE_CAP = float(os.environ.get("LEVERAGE_CAP", 0.999))

APPLY_PRIOR_FLAGS = os.environ.get("APPLY_PRIOR_FLAGS", "TRUE").upper() == "TRUE"
PRIOR_FLAG_SUFFIX = os.environ.get("PRIOR_FLAG_SUFFIX", ".flag_waterfall_round_3.h5")
APPLY_WHERE_INPAINTED_FLAGS = os.environ.get("APPLY_WHERE_INPAINTED_FLAGS", "TRUE").upper() == "TRUE"
WHERE_INPAINTED_SUFFIX = os.environ.get("WHERE_INPAINTED_SUFFIX", ".where_inpainted.h5")

SUBTRACT_POLARIZED_SOURCE = os.environ.get("SUBTRACT_POLARIZED_SOURCE", "FALSE").upper() == "TRUE"
SOURCE_YAML = os.environ.get(
    "SOURCE_YAML", "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/metadata/sources.yaml"
)

for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'FRF_SNR_SUFFIX', 'DLY_SNR_SUFFIX', 
                'FR_SPECTRA_FILE', 'AUTO_FR_SPECTRUM_FILE', 'PRIOR_FLAG_SUFFIX', 'WHERE_INPAINTED_SUFFIX']:
    print(f'{setting} = "{eval(setting)}"')
for setting in []:
    print(f'{setting} = {eval(setting)}')
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'FILTER_DELAY', 'EIGENVAL_CUTOFF', 'XTALK_FRATE',
                'SAVE_DLY_SNR', 'SAVE_FRF_SNR', 'FR_QUANTILE_LOW', 'FR_QUANTILE_HIGH', 'MIN_SAMP_FRAC',
                'APPLY_PRIOR_FLAGS', 'APPLY_WHERE_INPAINTED_FLAGS', 'SKIP_FR0_OVERLAP_BASELINES']:
    print(f'{setting} = {eval(setting)}')
RED_AVG_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459917/zen.2459917.25319.sum.smooth_calibrated.red_avg.uvh5"
CORNER_TURN_MAP_YAML = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459917/single_baseline_files/corner_turn_map.yaml"
FRF_SNR_SUFFIX = ".pI_FRF_SNR.uvh5"
DLY_SNR_SUFFIX = ".pI_DLYFILT_SNR.uvh5"
FR_SPECTRA_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/beam_simulation_products/spectra_cache_hera_core.h5"
AUTO_FR_SPECTRUM_FILE = "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5"
PRIOR_FLAG_SUFFIX = ".flag_waterfall_round_3.h5"
WHERE_INPAINTED_SUFFIX = ".where_inpainted.h5"
FM_LOW_FREQ = 87.5
FM_HIGH_FREQ = 108.0
FILTER_DELAY = 750.0
EIGENVAL_CUTOFF = 1e-12
XTALK_FRATE = 0.01
SAVE_DLY_SNR = True
SAVE_FRF_SNR = False
FR_QUANTILE_LOW = 0.05
FR_QUANTILE_HIGH = 0.95
MIN_SAMP_FRAC = 0.1
APPLY_PRIOR_FLAGS = True
APPLY_WHERE_INPAINTED_FLAGS = True
SKIP_FR0_OVERLAP_BASELINES = False

Preliminaries¶

In [4]:
with open(CORNER_TURN_MAP_YAML, 'r') as file:
    corner_turn_map = yaml.unsafe_load(file)
In [5]:
# get autocorrelations
# TODO: generalize for not-previously inpainted data
all_outfiles = [outfile.replace('.uvh5', '.inpainted.uvh5') for outfiles in corner_turn_map['files_to_outfiles_map'].values() for outfile in outfiles]
for outfile in all_outfiles:
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        hd_autos = io.HERAData(outfile)
        autos, _, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        break
In [6]:
# Load and combine prior flags if enabled
if APPLY_PRIOR_FLAGS:
    jdstr = [s for s in os.path.basename(RED_AVG_FILE).split('.') if s.isnumeric()][0]
    flag_dir = os.path.dirname(CORNER_TURN_MAP_YAML)
    flag_pattern = os.path.join(flag_dir, f'zen.{jdstr}*{PRIOR_FLAG_SUFFIX}')
    prior_flag_files = sorted(glob.glob(flag_pattern))
    
    if len(prior_flag_files) == 0:
        raise ValueError(f'APPLY_PRIOR_FLAGS is True but no files matched {flag_pattern}')
    else:
        print(f'Found {len(prior_flag_files)} prior flag files:')
        for f in prior_flag_files:
            print(f'  {os.path.basename(f)}')
        prior_flags = np.any([np.all(UVFlag(flag_file).flag_array, axis=-1) for flag_file in prior_flag_files], axis=0)
        print(f'Combined prior flags: {np.mean(prior_flags):.3%} flagged.')
Found 1 prior flag files:
  zen.2459917.flag_waterfall_round_3.h5
Combined prior flags: 27.792% flagged.
In [7]:
if SUBTRACT_POLARIZED_SOURCE:
    with open(SOURCE_YAML) as f:
        source_config = yaml.safe_load(f)

    model_dir = os.path.dirname(CORNER_TURN_MAP_YAML)  # directory containing the input data

    model_files = []
    for source in source_config["sources"]:
        fsname = source["name"].replace(" ", "_").replace("-", "_")
        model_files.extend(
            model_file for model_name in ("fg_model", "rm_model", "scint_model")
            if os.path.exists(
                model_file := os.path.join(model_dir, f"phased.{jdstr}.{fsname}.{model_name}.uvh5")
            )
        )

    print(f"Found {len(model_files)} model files")
    for model_file in model_files:
        print(f"  {os.path.basename(model_file)}")
Found 3 model files
  phased.2459917.PSR_J0628_28.fg_model.uvh5
  phased.2459917.PSR_J0628_28.rm_model.uvh5
  phased.2459917.PSR_J0628_28.scint_model.uvh5

Define Plotting and Helper Functions¶

In [8]:
def compute_mb_fr_ranges(hd, antpair):
    '''Compute main beam fringe rate ranges from beam simulation spectra.
    Returns per-frequency arrays: freqs in MHz, FR bounds in mHz.'''
    with h5py.File(FR_SPECTRA_FILE, "r") as h5f:
        metadata = h5f["metadata"]
        bl_to_index_map = {tuple(ap): int(index) for index, antpairs
                           in metadata["baseline_groups"].items() for ap in antpairs}
        spectrum_freqs = metadata["frequencies_MHz"][()] * 1e6
        m_modes = metadata["erh_mode_integer_index"][()]
        this_red_group = red_groups.RedundantGroups.from_antpos(hd.antpos)[antpair]
        for ap in this_red_group:
            if ap in bl_to_index_map:
                mmode_spectrum = h5f["erh_mode_power_spectrum"][:, :, bl_to_index_map[ap]]
                break
            elif ap[::-1] in bl_to_index_map:
                mmode_spectrum = h5f["erh_mode_power_spectrum"][:, :, bl_to_index_map[ap[::-1]]]
                m_modes = m_modes * -1
                break
        else:
            raise KeyError(f'{antpair}, nor any baseline redundant with it, was found in bl_to_index_map.')

    # Build mixing matrix from m-modes to fringe rates
    full_times = hd.times
    times_ks = (full_times - full_times[0] + np.median(np.diff(full_times))) * units.day.to(units.ks)
    filt_frates = np.fft.fftshift(np.fft.fftfreq(times_ks.size, d=np.median(np.diff(times_ks))))
    _m2f_mixer = get_m2f_mixer(times_ks, m_modes)

    # Vectorized: FR spectrum for every spectrum_freq channel at once
    # _m2f_mixer: (n_frates, n_mmodes), mmode_spectrum: (n_mmodes, n_spectrum_freqs)
    fr_spectra = np.abs(np.einsum("fm,mc,mf->fc", _m2f_mixer, mmode_spectrum, _m2f_mixer.T.conj()))
    # Normalize each channel
    fr_spectra /= fr_spectra.sum(axis=0, keepdims=True)

    # Compute quantile bounds per spectrum_freq channel
    cumsum = np.cumsum(fr_spectra, axis=0)
    spec_tops = np.array([np.interp(FR_QUANTILE_HIGH, cumsum[:, c], filt_frates)
                           for c in range(len(spectrum_freqs))])
    spec_bottoms = np.array([np.interp(FR_QUANTILE_LOW, cumsum[:, c], filt_frates)
                              for c in range(len(spectrum_freqs))])

    # Interpolate from spectrum_freqs to data freqs (with extrapolation)
    mb_frate_tops = interpolate.interp1d(spectrum_freqs, spec_tops, fill_value='extrapolate')(hd.freqs)
    mb_frate_bottoms = interpolate.interp1d(spectrum_freqs, spec_bottoms, fill_value='extrapolate')(hd.freqs)

    return (hd.freqs / 1e6, mb_frate_tops, mb_frate_bottoms)


def compute_sky_fr_ranges(hd, antpair):
    '''Compute sky fringe rate ranges using analytic sky_frates_single + empirical buffer.
    Returns per-frequency arrays in MHz and mHz.'''
    blvec = hd.antpos[antpair[0]] - hd.antpos[antpair[1]]
    latitude = hd.telescope.location.lat.rad
    sky_centers, sky_hws = sky_frates_single(hd.freqs, blvec, latitude)  # mHz
    fr_buffer = get_FR_buffer_from_spectra(AUTO_FR_SPECTRUM_FILE, hd.times, hd.freqs, gauss_fit_buffer_cut=1e-5)
    return (hd.freqs / 1e6, sky_centers + sky_hws + fr_buffer, sky_centers - sky_hws - fr_buffer)


def overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms):
    '''Check whether any band's main beam fringe rates overlap FR=0 ± XTALK_FRATE.'''
    for band in bands_bl:
        if band is None:
            continue
        # check if all FRs are above the FR=0 band
        if not ((np.all(mb_frate_tops[band] > XTALK_FRATE)) and (np.all(mb_frate_bottoms[band] > XTALK_FRATE))):
            # check if all FRs are below the FR=0 band
            if not ((np.all(mb_frate_tops[band] < -XTALK_FRATE)) and (np.all(mb_frate_bottoms[band] < -XTALK_FRATE))):
                return True
    return False

def subtract_polarized_models(data, flags, model_files, extra_flags=None):
    """
    Subtract phase-shifted models of polarized point sources from visibility data.

    For each model file, this function reads a model of a polarized source,
    computes the geometric phase shift needed to move the model from the
    phase center to the source's true sky position, and subtracts the
    phased model from the input data in-place. Visibilities are zeroed out
    wherever either the model or the data are flagged before subtraction.

    Parameters
    ----------
    data : HERAData
        Observed visibility data. Must contain antenna positions (`antpos`),
        frequencies (`freqs`), and observation times (`times`). Modified
        in-place: the model contribution is subtracted from each polarization.
    flags : dict
        Boolean flag dictionary keyed by (ant_i, ant_j, pol) tuples.
        True indicates a flagged (unusable) sample. Used in conjunction
        with model flags to zero out contributions before subtraction.
    model_files : list of str
        Paths to uvh5 model files, one per polarized source, per modeling component.
        Each file must contain:
          - A single baseline's worth of model visibilities
          - `SOURCE_RA` and `SOURCE_DEC` extra keywords (degrees)
            giving the ICRS sky position of the modeled source.

    Raises
    ------
    KeyError
        If a model file is missing the `SOURCE_RA` or `SOURCE_DEC` extra
        keywords, or if a polarization present in the model is absent from
        the data flags dictionary.
    """
    # Use the first baseline in the dataset to define the phasing baseline vector.
    # All model files are assumed to be for the same physical baseline geometry.
    ai, aj = bl = list(data.antpairs())[0]
    blvec = data.antpos[aj] - data.antpos[ai]  # East-North-Up baseline vector (meters)

    for model_file in model_files:
        hd_model = io.HERAData(model_file)
        model_data, model_flags, model_nsamples = hd_model.read()

        model_key = list(model_data.antpairs())[0]  # Baseline key stored in the model file
        model_pols = model_data.pols() 

        # Read the ICRS sky coordinates of the modeled source from file metadata
        right_ascension = hd_model.extra_keywords["SOURCE_RA"]   # degrees
        declination      = hd_model.extra_keywords["SOURCE_DEC"]  # degrees

        # Compute direction cosines (l, m, n) toward the source at each
        # observation time, accounting for Earth rotation and the telescope's
        # geographic location.
        lmn = polfilt.radec_to_lmn(
            right_ascension, declination, data.times, hd_model.telescope.location
        )  # shape: (n_times, 3)

        phasor = np.exp(
            2j * np.pi * np.dot(blvec, lmn)[:, np.newaxis] * hd_model.freqs / constants.c.value
        )  # shape: (n_times, n_freqs)

        for pol in model_pols:
            # Zero out samples where either the model or the data are flagged,
            # then apply the phase shift to move the model to the source position.
            flags_here = model_flags[model_key + (pol,)] | flags[bl + (pol,)]
            if extra_flags is not None:
                flags_here |= extra_flags
                
            bl_model = np.where(
                flags_here,
                0.0,
                model_data[model_key + (pol,)] * phasor
            )  # shape: (n_times, n_freqs)

            # Subtract the phased model from the data in-place
            data[bl + (pol,)] -= bl_model
In [9]:
def plot_fr_waterfall(snr_wf, flags_wf, taper_2d, freqs, times, title,
                      mb_frate_freqs_MHz=None, mb_frate_tops=None, mb_frate_bottoms=None,
                      sky_frate_freqs_MHz=None, sky_frate_tops=None, sky_frate_bottoms=None,
                      vmax=5):
    '''Plot freq vs fringe rate waterfall of |SNR| after FFT along time axis.
    Accepts pre-assembled full waterfalls with a 2D per-band taper.'''
    ntimes = len(times)
    times_in_seconds = (times - times[0]) * 24 * 3600
    frates = uvtools.utils.fourier_freqs(times_in_seconds) * 1000  # mHz

    # Per-column normalization accounting for taper and flags
    unflagged = (~flags_wf).astype(float)
    norm = (ntimes * np.mean((taper_2d * unflagged)**2, axis=0))**.5

    # FFT with per-band taper 
    to_plot = np.fft.fftshift(np.fft.fft(taper_2d * np.where((~flags_wf) & (taper_2d > 0), snr_wf, 0), axis=0), axes=0)
    to_plot = np.abs(to_plot) / norm[np.newaxis, :]

    fig = plt.figure(figsize=(14, 8), dpi=200)
    extent = [freqs[0] / 1e6, freqs[-1] / 1e6, frates[-1], frates[0]]
    im = plt.imshow(to_plot, aspect='auto', interpolation='none',
                    extent=extent, vmin=0, vmax=vmax, cmap='plasma')
    plt.colorbar(im, extend='max', label='|pI SNR|')
    plt.xlabel('Frequency (MHz)')
    plt.ylabel('Fringe Rate (mHz)')
    plt.title(title)

    if sky_frate_freqs_MHz is not None:
        plt.plot(sky_frate_freqs_MHz, sky_frate_tops, 'w:', lw=1, label='Sky FRs')
        plt.plot(sky_frate_freqs_MHz, sky_frate_bottoms, 'w:', lw=1)
        plt.ylim([-np.max([np.abs(sky_frate_tops), np.abs(sky_frate_bottoms)]) * 1.25,
                  np.max([np.abs(sky_frate_tops), np.abs(sky_frate_bottoms)]) * 1.25])
    else:
        plt.ylim([-5, 5])
    if mb_frate_freqs_MHz is not None:
        plt.plot(mb_frate_freqs_MHz, mb_frate_tops, 'w--', lw=1, label='Main Beam FRs')
        plt.plot(mb_frate_freqs_MHz, mb_frate_bottoms, 'w--', lw=1)
    if sky_frate_freqs_MHz is not None or mb_frate_freqs_MHz is not None:
        plt.legend()

    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_time_freq_waterfall(snr_wf, flags_wf, freqs, times, lsts, title, vmax=5):
    '''Plot freq vs time waterfall of |SNR| in real space with LST right axis.
    Accepts pre-assembled full waterfalls.'''
    to_plot = np.where(flags_wf, np.nan, np.abs(snr_wf))

    fig, ax = plt.subplots(figsize=(14, 8), dpi=200)
    extent = [freqs[0] / 1e6, freqs[-1] / 1e6,
              times[-1] - int(times[0]), times[0] - int(times[0])]
    im = ax.imshow(to_plot, aspect='auto', interpolation='none',
                   extent=extent, vmin=0, vmax=vmax, cmap='plasma')
    plt.colorbar(im, extend='max', label='|pI SNR|', ax=ax)
    ax.set_xlabel('Frequency (MHz)')
    ax.set_ylabel(f'JD - {int(times[0])}')
    ax.set_title(title)

    # Add LST right axis with proper wrapping
    lst_grid = lsts * 12 / np.pi  # radians to hours
    lst_grid[lst_grid > lst_grid[-1]] -= 24
    ax2 = ax.twinx()
    ax2.set_ylim(lst_grid[-1], lst_grid[0])
    mod24 = lambda x, _: f"{x % 24:.1f}"
    ax2.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(mod24))
    ax2.set_ylabel('LST (hours)')

    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_snr_histograms(snr_wf, flags_wf, title):
    '''Plot histogram of |SNR| compared to the Rayleigh distribution expected for noise-only.
    Accepts pre-assembled full waterfall and flags.'''
    fig = plt.figure(figsize=(12, 5))
    bins = np.arange(0, 10, .01)
    to_hist = np.abs(snr_wf[~flags_wf])
    to_hist = to_hist[np.isfinite(to_hist) & (to_hist > 0)]
    hist = plt.hist(to_hist, bins=bins, density=True, label='Real-space |pI SNR|')
    plt.plot(bins, 2 * bins * np.exp(-bins**2), 'k--', label='Rayleigh Distribution (Noise-Only)')
    plt.yscale('log')
    all_densities = hist[0][hist[0] > 0]
    if len(all_densities) > 0:
        plt.ylim(np.min(all_densities) / 2, np.max(all_densities) * 2)
    plt.legend()
    plt.ylabel('Density')
    plt.xlabel('|pI SNR|')
    plt.xlim([-.5, 10])
    plt.title(title)
    plt.tight_layout()
    plt.close(fig)
    return fig

Compute pI SNR, Looping Over Baselines¶

In [10]:
delay_fr_figs = []
dly_waterfall_figs = []
dly_histogram_figs = []
waterfall_figs = []
histogram_figs = []

for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:
    if not SAVE_DLY_SNR and not SAVE_FRF_SNR:
        continue

    # Load data
    single_bl_file = single_bl_file.replace('.uvh5', '.inpainted.uvh5')
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read(polarizations=['ee', 'nn'])
    dt = np.median(np.diff(hd.times)) * 24 * 3600
    df = np.median(np.diff(hd.freqs))
    antpair = data.antpairs().pop()

    if antpair[0] == antpair[1]:
        print('\tThis baseline is an autocorrelation. Skipping...')
        continue

    med_auto_nsamples = {bl[2]: np.median(n) for bl, n in auto_nsamples.items()}
    if not any([np.median(nsamples[bl]) > MIN_SAMP_FRAC * med_auto_nsamples[bl[2]] for bl in nsamples]):
        print('\tNo polarization has enough nsamples to be worth filtering. Skipping...')
        continue

    # Combine flags across pols, including prior flags if enabled
    flags_here = flags[antpair + ('ee',)] | flags[antpair + ('nn',)]
    if APPLY_PRIOR_FLAGS:
        flags_here |= prior_flags

    # OR in per-baseline where_inpainted flags to make sure they proagate into full_day_rfi_round_4 and full_day_rfi_round_5
    if APPLY_WHERE_INPAINTED_FLAGS:
        wip_flags = np.zeros_like(flags_here)
        where_inpainted_file = single_bl_file.replace('.inpainted.uvh5', WHERE_INPAINTED_SUFFIX)
        if not os.path.exists(where_inpainted_file):
            raise FileNotFoundError(
                f'APPLY_WHERE_INPAINTED_FLAGS is True but no file found at {where_inpainted_file}')
        uvf_wi = UVFlag(where_inpainted_file)
        # uvf_wi.flag_array has shape (Ntimes, Nfreqs, Npols); collapse pol axis with np.any
        # to match the ee|nn combine semantics used for flags_here.
        wip_flags = np.any(uvf_wi.flag_array, axis=-1)
        if wip_flags.shape != flags_here.shape:
            raise ValueError(
                f'where_inpainted flags shape {wip_flags.shape} does not match '
                f'flags_here shape {flags_here.shape} for {where_inpainted_file}')
        flags_here |= wip_flags
        del uvf_wi

    # Get tslice and bands split around FM
    tslices_bl, bands_bl = flag_utils.get_minimal_slices(flags_here,
        freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])

    if SUBTRACT_POLARIZED_SOURCE:
        print ("\tSubtracting off a polarized source model from the visibilities...")
        subtract_polarized_models(data, flags, model_files, extra_flags=flags_here)

    # Compute FR ranges for this baseline
    print('\tComputing fringe rate ranges...')
    (mb_frate_freqs_MHz, mb_frate_tops, mb_frate_bottoms) = compute_mb_fr_ranges(hd, antpair)
    (sky_frate_freqs_MHz, sky_frate_tops, sky_frate_bottoms) = compute_sky_fr_ranges(hd, antpair)

    # Decide whether to process baselines whose main beam overlaps FR = 0 ± XTALK_FRATE.
    # By default these are skipped, since they often contain residual crosstalk-like
    # structure at FR=0 that's probably not RFI. To process them anyway, set
    # SKIP_FR0_OVERLAP_BASELINES=FALSE -- the FRF will then apply a per-channel FR=0
    # notch (and a leverage correction that accounts for it) on overlap channels.
    fr0_overlap = overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms)
    if fr0_overlap and SKIP_FR0_OVERLAP_BASELINES:
        print(f'\tThis baseline overlaps FR = 0 ± {XTALK_FRATE} mHz. Skipping. '
              f'Set SKIP_FR0_OVERLAP_BASELINES=FALSE to process these with a per-channel FR=0 notch in the FRF step.')
        continue
    if fr0_overlap and SAVE_FRF_SNR:
        print(f'\tThis baseline overlaps FR = 0 ± {XTALK_FRATE} mHz. The FRF will apply a per-channel FR=0 notch.')
    elif fr0_overlap:
        print(f'\tWARNING: this baseline overlaps FR = 0 ± {XTALK_FRATE} mHz, but SAVE_FRF_SNR is False -- '
               'no FR=0 notch will be applied (as it is part of the FRF step).')

    # Process each band
    filt_flags_full = np.ones((len(hd.times), len(hd.freqs)), dtype=bool)
    newly_flagged = np.zeros((len(hd.times), len(hd.freqs)), dtype=bool)
    dly_filt_SNR_full = np.full((len(hd.times), len(hd.freqs)), np.nan, dtype=complex)
    frf_SNR_full = np.full((len(hd.times), len(hd.freqs)), np.nan, dtype=complex)
    taper_2d = np.zeros((len(hd.times), len(hd.freqs)))

    for tslice, band in zip(tslices_bl, bands_bl):
        if (band is None) or np.all(flags_here[tslice, band]):
            continue

        # Extract per-pol data for this band
        d_ee = data[antpair + ('ee',)][tslice, band]
        d_nn = data[antpair + ('nn',)][tslice, band]
        f_ee = flags[antpair + ('ee',)][tslice, band]
        f_nn = flags[antpair + ('nn',)][tslice, band]
        n_ee = nsamples[antpair + ('ee',)][tslice, band]
        n_nn = nsamples[antpair + ('nn',)][tslice, band]
        a_ee = np.abs(autos[autos.antpairs().pop() + ('ee',)][tslice, band])
        a_nn = np.abs(autos[autos.antpairs().pop() + ('nn',)][tslice, band])

        # Compute variance from autos
        var_pI = a_ee**2 / (dt * df) / n_ee + a_nn**2 / (dt * df) / n_nn

        # Form pseudo-Stokes pI
        d_pI, f_pI, n_pI = hp.pstokes._combine_pol_arrays(
            'ee', 'nn', 'pI', pol_convention=hd.pol_convention,
            data_list=[d_ee, d_nn], flags_list=[f_ee, f_nn],
            nsamples_list=[n_ee, n_nn],
            x_orientation=hd.telescope.get_x_orientation_from_feeds())

        if APPLY_PRIOR_FLAGS:
            f_pI |= prior_flags[tslice, band]
        if APPLY_WHERE_INPAINTED_FLAGS:
            f_pI |= wip_flags[tslice, band]

        d_pI[f_pI] = 0

        # Compute SNR
        SNR = d_pI / var_pI**.5

        # Delay filter using dpss_matrix (pinv-based) for stability across large frequency gaps.
        # dpss_leastsq and dpss_solve are faster but can produce unstable results when X^T W X
        # is ill-conditioned due to large contiguous gaps in the flagging.
        print(f'\tDelay filtering band {band}...')
        wgts = np.where(f_pI, 0, 1).astype(float)
        filter_kwargs = dict(filter_centers=[0], filter_half_widths=[FILTER_DELAY * 1e-9],
            eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
            max_contiguous_edge_flags=len(hd.freqs), ridge_alpha=RIDGE_ALPHA)
        result, _, info = fourier_filter(hd.freqs[band], SNR, wgts=wgts,
            mode='dpss_matrix', **filter_kwargs)
        dly_filt_SNR = SNR - result

        # Per-integration weighted leverage correction
        X = dpss_operator(hd.freqs[band], [0],
            filter_half_widths=[FILTER_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
        correction_cache = {}
        for i in range(SNR.shape[0]):
            w = wgts[i]
            if not np.any(w):
                continue
            cache_key = f_pI[i].tobytes()
            if cache_key not in correction_cache:
                XtWX = (X.T * w) @ X
                XtWX += np.diag(np.diag(XtWX)) * RIDGE_ALPHA
                if np.all(np.isclose(XtWX.imag, 0)):
                    XtWX = np.real(XtWX)
                try:
                    P = np.linalg.pinv(XtWX, hermitian=True) @ (X.T * w)
                    lev = np.real(np.sum(X * P.T, axis=1))
                    lev = np.clip(lev, 0.0, LEVERAGE_CAP)
                    correction = (1 - lev)**.5
                    correction_cache[cache_key] = np.where(np.isfinite(correction) & (lev > 0) & (lev < 1), correction, np.nan)
                except np.linalg.LinAlgError:
                    correction_cache[cache_key] = np.full(X.shape[0], np.nan)
            dly_filt_SNR[i] /= correction_cache[cache_key]

        # update larger arrays and account for non-finite SNR (degenerate leverage)
        newly_flagged[tslice, band] |= ~np.isfinite(dly_filt_SNR) & ~f_pI
        f_pI |= newly_flagged[tslice, band]
        dly_filt_SNR[newly_flagged[tslice, band]] = 0
        filt_flags_full[tslice, band] = f_pI
        dly_filt_SNR_full[tslice, band] = dly_filt_SNR

        if SAVE_FRF_SNR:
            # The FR=0 notch and main-beam FRF act on a copy of the delay-filtered
            # data; dly_filt_SNR (and dly_filt_SNR_full above) are left untouched.
            times_in_seconds = (hd.times[tslice] - hd.times[tslice][0]) * 24 * 3600
            frf_SNR = copy.deepcopy(dly_filt_SNR)

            # ---- FR=0 notch on frf_SNR, per-channel (always, when SAVE_FRF_SNR) ----
            # The cache stores (correction_per_t, P_fr0); P_fr0 is reused below to
            # compute the leverage of the composed FR=0 + main-beam operator.
            print(f'\tFR=0 notch filtering band {band}...')
            Xt_fr0 = dpss_operator(hd.times[tslice] * 24 * 3600, filter_centers=[0],
                filter_half_widths=[XTALK_FRATE / 1000], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
            fr0_correction_cache = {}
            for chan, (fr_low, fr_high) in enumerate(zip(mb_frate_bottoms[band], mb_frate_tops[band])):
                w_fr0 = np.where(f_pI[:, chan], 0, 1).astype(float)
                if not np.any(w_fr0):
                    continue
                cache_key_fr0 = f_pI[:, chan].tobytes()
                if cache_key_fr0 not in fr0_correction_cache:
                    XtWXt_fr0 = np.dot(Xt_fr0.conj().T * w_fr0, Xt_fr0)
                    XtWXt_fr0 += np.diag(np.diag(XtWXt_fr0)) * RIDGE_ALPHA
                    try:
                        P_fr0 = np.linalg.pinv(XtWXt_fr0, hermitian=True) @ (Xt_fr0.conj().T * w_fr0)
                        lev_fr0 = np.real(np.sum(Xt_fr0 * P_fr0.T, axis=1))
                        lev_fr0 = np.clip(lev_fr0, 0.0, LEVERAGE_CAP)
                        correction_fr0 = (1 - lev_fr0)**.5
                        correction_fr0 = np.where(
                            np.isfinite(correction_fr0) & (lev_fr0 > 0) & (lev_fr0 < 1),
                            correction_fr0, np.nan)
                        fr0_correction_cache[cache_key_fr0] = (correction_fr0, P_fr0)
                    except np.linalg.LinAlgError:
                        fr0_correction_cache[cache_key_fr0] = (np.full(Xt_fr0.shape[0], np.nan), None)
                result_fr0, _, _ = fourier_filter(times_in_seconds,
                    frf_SNR[:, chan:chan+1],
                    wgts=np.where(f_pI[:, chan:chan+1], 0, 1),
                    filter_centers=[0], filter_half_widths=[XTALK_FRATE / 1000],
                    mode='dpss_leastsq', eigenval_cutoff=[EIGENVAL_CUTOFF],
                    suppression_factors=[EIGENVAL_CUTOFF],
                    max_contiguous_edge_flags=len(hd.times[tslice]),
                    filter_dims=0, ridge_alpha=RIDGE_ALPHA)
                frf_SNR[:, chan:chan+1] -= result_fr0
                frf_SNR[:, chan] /= fr0_correction_cache[cache_key_fr0][0]

            # The per-time divide by sqrt(1 - lev_fr0) above can introduce NaN at
            # degenerate-leverage times. Zero those before the main FR fit so they
            # don't poison fourier_filter's masked least-squares (NaN at mask=True
            # positions propagates through Xy = amat[mask].T @ y[mask]).
            newly_flagged[tslice, band] |= ~np.isfinite(frf_SNR) & ~f_pI
            f_pI |= newly_flagged[tslice, band]
            frf_SNR[newly_flagged[tslice, band]] = 0
            filt_flags_full[tslice, band] = f_pI

            # ---- Per-channel FR filter at the main-beam center ----
            print(f'\tFR filtering band {band}...')
            for chan, (fr_low, fr_high) in enumerate(zip(mb_frate_bottoms[band], mb_frate_tops[band])):
                fr_center = (fr_low + fr_high) / 2 / 1000  # mHz -> Hz
                fr_halfwidth = (fr_high - fr_low) / 2 / 1000  # mHz -> Hz

                # filter directly at the main-beam fringe rate center
                result, _, _ = fourier_filter(times_in_seconds, frf_SNR[:, chan:chan+1],
                    wgts=np.where(f_pI[:, chan:chan+1], 0, 1), filter_centers=[fr_center],
                    filter_half_widths=[fr_halfwidth], mode='dpss_leastsq',
                    eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
                    max_contiguous_edge_flags=len(data.times), filter_dims=0)
                Xt = dpss_operator(hd.times[tslice] * 24 * 3600, filter_centers=[fr_center],
                    filter_half_widths=[fr_halfwidth], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
                W = np.where(f_pI[:, chan], 0, 1).astype(float)
                XtWXt = np.dot(Xt.conj().T * W, Xt)
                XtWXt += np.diag(np.diag(XtWXt)) * RIDGE_ALPHA
                try:
                    P = np.linalg.pinv(XtWXt, hermitian=True) @ (Xt.conj().T * W)

                    # Per-time noise variance after the composed operator
                    #   M = H_main @ D_fr0 @ (I - H_fr0)
                    # acts on raw unit-variance noise. Factored form:
                    #   Q = P*D_fr0  -  (P*D_fr0 @ Xt_fr0) @ P_fr0
                    #   var_t = (Xt @ S @ Xt^H)_tt,  S = Q @ Q^H
                    # avoids ever materialising the n_t x n_t hat matrices.
                    cache_entry = fr0_correction_cache.get(f_pI[:, chan].tobytes())
                    if cache_entry is not None and cache_entry[1] is not None:
                        correction_fr0, P_fr0 = cache_entry
                        D_fr0 = np.where(np.isfinite(correction_fr0), 1.0 / correction_fr0, 0.0)
                        PD = P * D_fr0[None, :]
                        Q = PD - (PD @ Xt_fr0) @ P_fr0
                    else:
                        Q = P
                    S = Q @ Q.conj().T
                    var_t = np.real(np.einsum('tb,bc,tc->t', Xt, S, Xt.conj()))
                    lev_t_correction = np.where(var_t > 0, np.sqrt(var_t), np.nan)
                except np.linalg.LinAlgError:
                    lev_t_correction = np.full(Xt.shape[0], np.nan)

                frf_SNR[:, chan:chan+1] = np.where(f_pI[:, chan:chan+1], 0, result / lev_t_correction[:, None])

            # account for non-finite SNR again
            frf_SNR_full[tslice, band] = frf_SNR
            newly_flagged[tslice, band] |= ~np.isfinite(frf_SNR) & ~f_pI

        # build 2D taper for plotting
        band_ntimes = tslice.stop - tslice.start
        taper_2d[tslice, band] = blackmanharris(band_ntimes)[:, np.newaxis]

    if np.any(newly_flagged):
        print(f'\tFlagging {np.sum(newly_flagged)} waterfall pixels due to NaNs and infs, typically due to degenerate leverage.')
        filt_flags_full |= newly_flagged
    
    if np.all(filt_flags_full):
        print(f'\t{antpair} is entirely flagged.')
        continue

    # Now produce figures to display later
    delay_fr_figs.append(plot_fr_waterfall(
        dly_filt_SNR_full, filt_flags_full, taper_2d, hd.freqs, hd.times,
        f'{antpair} Delay-Filtered pI',
        mb_frate_freqs_MHz=mb_frate_freqs_MHz, mb_frate_tops=mb_frate_tops, mb_frate_bottoms=mb_frate_bottoms,
        sky_frate_freqs_MHz=sky_frate_freqs_MHz, sky_frate_tops=sky_frate_tops, sky_frate_bottoms=sky_frate_bottoms))

    dly_waterfall_figs.append(plot_time_freq_waterfall(
        dly_filt_SNR_full, filt_flags_full, hd.freqs, hd.times, hd.lsts,
        f'{antpair} Delay-Filtered pI'))

    dly_histogram_figs.append(plot_snr_histograms(
        dly_filt_SNR_full, filt_flags_full, f'{antpair} Delay-Filtered pI'))

    if SAVE_FRF_SNR:
        waterfall_figs.append(plot_time_freq_waterfall(
            frf_SNR_full, filt_flags_full, hd.freqs, hd.times, hd.lsts,
            f'{antpair} Delay+FR-Filtered pI (+ FR=0 notch)'))
        histogram_figs.append(plot_snr_histograms(
            frf_SNR_full, filt_flags_full, f'{antpair} Delay+FR-Filtered pI (+ FR=0 notch)'))

    # Write SNR outputs
    if SAVE_DLY_SNR or SAVE_FRF_SNR:
        hd.select(polarizations=['ee'])
        hd.polarization_array[0] = utils.polstr2num('pI')
        for bl in list(data.keys()):
            if bl[2] != 'ee':
                del data[bl]
                del flags[bl]
        for to_save, snr_full, suffix, label in [(SAVE_DLY_SNR, dly_filt_SNR_full, DLY_SNR_SUFFIX, 'delay-filtered'),
                                                 (SAVE_FRF_SNR, frf_SNR_full, FRF_SNR_SUFFIX, 'delay-filtered and FRFed')]:
            if not to_save:
                continue
            data[antpair + ('ee',)] = np.where(np.isfinite(snr_full), snr_full, 0)
            flags[antpair + ('ee',)] = filt_flags_full
            hd.update(data=data, flags=flags)
            outfile = single_bl_file.replace('.uvh5', suffix)
            print(f'\tWriting {label} results to {outfile}')
            hd.write_uvh5(outfile, clobber=True)
Now loading /lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459917/single_baseline_files/zen.2459917.baseline.0_4.sum.smooth_calibrated.red_avg.inpainted.uvh5
	Subtracting off a polarized source model from the visibilities...
	Computing fringe rate ranges...
	Delay filtering band slice(np.int64(0), np.int64(333), None)...
invalid value encountered in divide
	Delay filtering band slice(np.int64(502), np.int64(1493), None)...
invalid value encountered in divide
	Writing delay-filtered results to /lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459917/single_baseline_files/zen.2459917.baseline.0_4.sum.smooth_calibrated.red_avg.inpainted.pI_DLYFILT_SNR.uvh5

Figure 1: Delay-Filtered pI SNR in Fringe Rate Space¶

In [11]:
for fig in delay_fr_figs:
    display(fig)
No description has been provided for this image

Figure 2: Delay-Filtered pI SNR Waterfall¶

In [12]:
for fig in dly_waterfall_figs:
    display(fig)
No description has been provided for this image

Figure 3: Delay-Filtered pI SNR Histogram¶

In [13]:
for fig in dly_histogram_figs:
    display(fig)
No description has been provided for this image

Figure 4: Delay+FR Filtered pI SNR Waterfall¶

In [14]:
for fig in waterfall_figs:
    display(fig)

Figure 5: Delay+FR Filtered pI SNR Histogram¶

In [15]:
for fig in histogram_figs:
    display(fig)

Metadata¶

In [16]:
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'hera_pspec', 'pyuvdata', 'numpy']:
    exec(f'from {repo} import __version__')
    print(f'{repo}: {__version__}')
hera_cal: 3.7.8.dev62+gf11268d54
hera_qm: 2.2.1.dev12+g95ecc30f0
hera_filters: 0.1.9.dev3+ged4deb46a
hera_notebook_templates: 0.0.1.dev1457+gb078d937f
hera_pspec: 0.4.3.dev99+g0e4b0e22b
pyuvdata: 3.2.5
numpy: 2.3.5
In [17]:
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')
Finished execution in 1.39 minutes.