Single Baseline pI FRF SNR¶
by Josh Dillon and Tyler Cox, last updated May 26, 2026
This notebook takes corner-turned, calibrated, redundantly-averaged visibility data, forms pseudo-Stokes pI, and computes delay-filtered, fringe-rate-filtered SNR waterfalls. The results are written out as uvh5 files to be combined across baselines to look for residual structure that fringes like the main beam.
Here's a set of links to skip to particular figures and tables:
• Figure 1: Delay-Filtered pI SNR in Fringe Rate Space¶
• Figure 2: Delay-Filtered pI SNR Waterfall¶
• Figure 3: Delay-Filtered pI SNR Histogram¶
• Figure 4: Delay+FR Filtered pI SNR Waterfall¶
• Figure 5: Delay+FR Filtered pI SNR Histogram¶
In [1]:
# Record the wall-clock start time; the final cell uses tstart to report the total run time.
import time
tstart = time.time()
# Log which host executed the notebook and when (IPython shell escapes).
!hostname
!date
herapost010
Sat May 30 13:11:16 MDT 2026
In [2]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin
import numpy as np
import yaml
import copy
import re
import glob
import matplotlib
from astropy import units, constants
from scipy import interpolate
from scipy.signal.windows import blackmanharris
from pyuvdata import UVFlag
from hera_cal import io, utils, flag_utils, red_groups, polfilt
from hera_cal.frf import sky_frates_single, get_FR_buffer_from_spectra, get_m2f_mixer
from hera_filters.dspec import dpss_operator, fourier_filter
import hera_pspec as hp
import uvtools
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline
In [3]:
# ---------------------------------------------------------------------------
# Notebook settings. Every setting can be overridden through an environment
# variable of the same name; boolean settings are parsed as the string "TRUE"
# (case-insensitive), anything else is False.
# ---------------------------------------------------------------------------

# Input file whose entry in the corner-turn map selects the single-baseline files to process
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459935/zen.2459935.21408.sum.smooth_calibrated.red_avg.uvh5" # TODO
if RED_AVG_FILE is None:
    # Fail early with a clear message: RED_AVG_FILE is used below to build the default
    # corner-turn map path and later as the key into the corner-turn map.
    raise ValueError('RED_AVG_FILE must be set, either via the RED_AVG_FILE environment variable '
                     'or by editing this cell.')
CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML",
    os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml"))

# Output control
FRF_SNR_SUFFIX = os.environ.get("FRF_SNR_SUFFIX", ".pI_FRF_SNR.uvh5")
SAVE_DLY_SNR = os.environ.get("SAVE_DLY_SNR", "TRUE").upper() == "TRUE"
DLY_SNR_SUFFIX = os.environ.get("DLY_SNR_SUFFIX", ".pI_DLYFILT_SNR.uvh5")
SAVE_FRF_SNR = os.environ.get("SAVE_FRF_SNR", "TRUE").upper() == "TRUE"

# Band definition and delay filter
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5)) # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0)) # in MHz
FILTER_DELAY = float(os.environ.get("FILTER_DELAY", 750)) # in ns
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))

# Fringe-rate filter
FR_SPECTRA_FILE = os.environ.get("FR_SPECTRA_FILE",
    "/lustre/aoc/projects/hera/h6c-analysis/IDR3/beam_simulation_products/spectra_cache_hera_core.h5")
AUTO_FR_SPECTRUM_FILE = os.environ.get("AUTO_FR_SPECTRUM_FILE",
    "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5")
XTALK_FRATE = float(os.environ.get("XTALK_FRATE", 0.01)) # in mHz
SKIP_FR0_OVERLAP_BASELINES = os.environ.get("SKIP_FR0_OVERLAP_BASELINES", "FALSE").upper() == "TRUE"
FR_QUANTILE_LOW = float(os.environ.get("FR_QUANTILE_LOW", 0.05))
FR_QUANTILE_HIGH = float(os.environ.get("FR_QUANTILE_HIGH", 0.95))

# Data-quality cuts and numerical regularization
MIN_SAMP_FRAC = float(os.environ.get("MIN_SAMP_FRAC", 0.1))
RIDGE_ALPHA = float(os.environ.get("RIDGE_ALPHA", 1e-12))
LEVERAGE_CAP = float(os.environ.get("LEVERAGE_CAP", 0.999))

# Extra flags to OR into the data
APPLY_PRIOR_FLAGS = os.environ.get("APPLY_PRIOR_FLAGS", "TRUE").upper() == "TRUE"
PRIOR_FLAG_SUFFIX = os.environ.get("PRIOR_FLAG_SUFFIX", ".flag_waterfall_round_3.h5")
APPLY_WHERE_INPAINTED_FLAGS = os.environ.get("APPLY_WHERE_INPAINTED_FLAGS", "TRUE").upper() == "TRUE"
WHERE_INPAINTED_SUFFIX = os.environ.get("WHERE_INPAINTED_SUFFIX", ".where_inpainted.h5")

# Optional polarized point-source model subtraction
SUBTRACT_POLARIZED_SOURCE = os.environ.get("SUBTRACT_POLARIZED_SOURCE", "FALSE").upper() == "TRUE"
SOURCE_YAML = os.environ.get(
    "SOURCE_YAML", "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/metadata/sources.yaml"
)

# Echo every setting so the executed notebook records exactly how it was configured.
# String-valued settings are quoted; numeric and boolean settings are printed bare.
for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'FRF_SNR_SUFFIX', 'DLY_SNR_SUFFIX',
                'FR_SPECTRA_FILE', 'AUTO_FR_SPECTRUM_FILE', 'PRIOR_FLAG_SUFFIX', 'WHERE_INPAINTED_SUFFIX',
                'SOURCE_YAML']:
    print(f'{setting} = "{globals()[setting]}"')
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'FILTER_DELAY', 'EIGENVAL_CUTOFF', 'XTALK_FRATE',
                'SAVE_DLY_SNR', 'SAVE_FRF_SNR', 'FR_QUANTILE_LOW', 'FR_QUANTILE_HIGH', 'MIN_SAMP_FRAC',
                'APPLY_PRIOR_FLAGS', 'APPLY_WHERE_INPAINTED_FLAGS', 'SKIP_FR0_OVERLAP_BASELINES',
                'RIDGE_ALPHA', 'LEVERAGE_CAP', 'SUBTRACT_POLARIZED_SOURCE']:
    print(f'{setting} = {globals()[setting]}')
RED_AVG_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459920/zen.2459920.25326.sum.smooth_calibrated.red_avg.uvh5" CORNER_TURN_MAP_YAML = "/lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459920/single_baseline_files/corner_turn_map.yaml" FRF_SNR_SUFFIX = ".pI_FRF_SNR.uvh5" DLY_SNR_SUFFIX = ".pI_DLYFILT_SNR.uvh5" FR_SPECTRA_FILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/beam_simulation_products/spectra_cache_hera_core.h5" AUTO_FR_SPECTRUM_FILE = "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5" PRIOR_FLAG_SUFFIX = ".flag_waterfall_round_3.h5" WHERE_INPAINTED_SUFFIX = ".where_inpainted.h5" FM_LOW_FREQ = 87.5 FM_HIGH_FREQ = 108.0 FILTER_DELAY = 750.0 EIGENVAL_CUTOFF = 1e-12 XTALK_FRATE = 0.01 SAVE_DLY_SNR = True SAVE_FRF_SNR = False FR_QUANTILE_LOW = 0.05 FR_QUANTILE_HIGH = 0.95 MIN_SAMP_FRAC = 0.1 APPLY_PRIOR_FLAGS = True APPLY_WHERE_INPAINTED_FLAGS = True SKIP_FR0_OVERLAP_BASELINES = False
Preliminaries¶
In [4]:
# Load the corner-turn map; later cells index corner_turn_map['files_to_outfiles_map'] to find the
# single-baseline files.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects, so CORNER_TURN_MAP_YAML must
# only ever point at a trusted, pipeline-generated file. Presumably the map contains Python-specific
# tags that yaml.safe_load would reject -- confirm before switching loaders.
with open(CORNER_TURN_MAP_YAML, 'r') as file:
    corner_turn_map = yaml.unsafe_load(file)
In [5]:
# get autocorrelations
# TODO: generalize for not-previously inpainted data
# Every single-baseline output listed in the corner-turn map, pointed at its inpainted version.
all_outfiles = [outfile.replace('.uvh5', '.inpainted.uvh5')
                for outfiles in corner_turn_map['files_to_outfiles_map'].values()
                for outfile in outfiles]
for outfile in all_outfiles:
    # File names carry the baseline as ".<ant1>_<ant2>."; an autocorrelation has ant1 == ant2.
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        hd_autos = io.HERAData(outfile)
        autos, _, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        break
else:
    # Without this, a missing auto file only surfaces much later as a NameError on
    # `autos` / `auto_nsamples` inside the per-baseline loop.
    raise ValueError(f'No autocorrelation file (ant1 == ant2) found among the {len(all_outfiles)} '
                     f'single-baseline files listed in {CORNER_TURN_MAP_YAML}')
In [6]:
# Load and combine prior flags if enabled
if APPLY_PRIOR_FLAGS:
    # The integer JD is the first purely numeric dot-separated token of the file name
    jdstr = [s for s in os.path.basename(RED_AVG_FILE).split('.') if s.isnumeric()][0]

    # Prior flag files live next to the corner-turn map and are matched by JD and suffix
    flag_dir = os.path.dirname(CORNER_TURN_MAP_YAML)
    flag_pattern = os.path.join(flag_dir, f'zen.{jdstr}*{PRIOR_FLAG_SUFFIX}')
    prior_flag_files = sorted(glob.glob(flag_pattern))
    if not prior_flag_files:
        raise ValueError(f'APPLY_PRIOR_FLAGS is True but no files matched {flag_pattern}')

    print(f'Found {len(prior_flag_files)} prior flag files:')
    for f in prior_flag_files:
        print(f' {os.path.basename(f)}')

    # Within one file a pixel counts as flagged only if flagged along the whole last axis;
    # across files a pixel is flagged if any file flags it.
    per_file_flags = [np.all(UVFlag(flag_file).flag_array, axis=-1) for flag_file in prior_flag_files]
    prior_flags = np.any(per_file_flags, axis=0)
    print(f'Combined prior flags: {np.mean(prior_flags):.3%} flagged.')
Found 1 prior flag files: zen.2459920.flag_waterfall_round_3.h5
Combined prior flags: 28.754% flagged.
In [7]:
# Locate the polarized-source model files to subtract, if enabled
if SUBTRACT_POLARIZED_SOURCE:
    # Derive the integer JD here rather than relying on the prior-flags cell, which only
    # defines jdstr when APPLY_PRIOR_FLAGS is True.
    jdstr = [s for s in os.path.basename(RED_AVG_FILE).split('.') if s.isnumeric()][0]

    with open(SOURCE_YAML) as f:
        source_config = yaml.safe_load(f)

    model_dir = os.path.dirname(CORNER_TURN_MAP_YAML) # directory containing the input data
    model_files = []
    for source in source_config["sources"]:
        # File-system-safe source name: spaces and dashes become underscores
        fsname = source["name"].replace(" ", "_").replace("-", "_")
        for model_name in ("fg_model", "rm_model", "scint_model"):
            model_file = os.path.join(model_dir, f"phased.{jdstr}.{fsname}.{model_name}.uvh5")
            # Model components that were not produced for this source are silently skipped
            if os.path.exists(model_file):
                model_files.append(model_file)

    print(f"Found {len(model_files)} model files")
    for model_file in model_files:
        print(f" {os.path.basename(model_file)}")
Found 6 model files phased.2459920.PSR_J0628_28.fg_model.uvh5 phased.2459920.PSR_J0628_28.rm_model.uvh5 phased.2459920.PSR_J0628_28.scint_model.uvh5 phased.2459920.PSR_J0742_2822.fg_model.uvh5 phased.2459920.PSR_J0742_2822.rm_model.uvh5 phased.2459920.PSR_J0742_2822.scint_model.uvh5
Define Plotting and Helper Functions¶
In [8]:
def compute_mb_fr_ranges(hd, antpair):
    '''Compute main beam fringe rate ranges from beam simulation spectra.

    Reads the m-mode power spectrum for this baseline (or any baseline redundant with it,
    in either orientation) from FR_SPECTRA_FILE, mixes it into fringe-rate space on the time
    grid of hd, and takes the FR_QUANTILE_LOW / FR_QUANTILE_HIGH quantiles of the normalized
    cumulative fringe-rate spectrum at each simulated frequency. The bounds are then linearly
    interpolated (and extrapolated) onto hd.freqs.

    Parameters
    ----------
    hd : object providing antpos, times and freqs (an io.HERAData in this notebook)
    antpair : tuple
        Antenna pair used to look up the redundant group.

    Returns
    -------
    tuple
        (hd.freqs / 1e6, mb_frate_tops, mb_frate_bottoms): per-frequency arrays, freqs in MHz,
        FR bounds in mHz.

    Raises
    ------
    KeyError
        If neither antpair nor any baseline redundant with it is in the spectra file.
    '''
    with h5py.File(FR_SPECTRA_FILE, "r") as h5f:
        metadata = h5f["metadata"]
        # Map every antenna pair listed in the file to the index of its baseline group
        bl_to_index_map = {tuple(ap): int(index) for index, antpairs
                           in metadata["baseline_groups"].items() for ap in antpairs}
        spectrum_freqs = metadata["frequencies_MHz"][()] * 1e6
        m_modes = metadata["erh_mode_integer_index"][()]
        this_red_group = red_groups.RedundantGroups.from_antpos(hd.antpos)[antpair]
        for ap in this_red_group:
            if ap in bl_to_index_map:
                mmode_spectrum = h5f["erh_mode_power_spectrum"][:, :, bl_to_index_map[ap]]
                break
            elif ap[::-1] in bl_to_index_map:
                # Only the reversed baseline is stored: use its spectrum with the m-mode signs flipped
                mmode_spectrum = h5f["erh_mode_power_spectrum"][:, :, bl_to_index_map[ap[::-1]]]
                m_modes = m_modes * -1
                break
        else:
            # for/else: reached only when the loop finished without finding a match
            raise KeyError(f'{antpair}, nor any baseline redundant with it, was found in bl_to_index_map.')
    # Build mixing matrix from m-modes to fringe rates
    full_times = hd.times
    # Times in kiloseconds (offset by one integration), so the fringe rates below come out in mHz
    times_ks = (full_times - full_times[0] + np.median(np.diff(full_times))) * units.day.to(units.ks)
    filt_frates = np.fft.fftshift(np.fft.fftfreq(times_ks.size, d=np.median(np.diff(times_ks))))
    _m2f_mixer = get_m2f_mixer(times_ks, m_modes)
    # Vectorized: FR spectrum for every spectrum_freq channel at once
    # _m2f_mixer: (n_frates, n_mmodes), mmode_spectrum: (n_mmodes, n_spectrum_freqs)
    fr_spectra = np.abs(np.einsum("fm,mc,mf->fc", _m2f_mixer, mmode_spectrum, _m2f_mixer.T.conj()))
    # Normalize each channel
    fr_spectra /= fr_spectra.sum(axis=0, keepdims=True)
    # Compute quantile bounds per spectrum_freq channel
    cumsum = np.cumsum(fr_spectra, axis=0)
    spec_tops = np.array([np.interp(FR_QUANTILE_HIGH, cumsum[:, c], filt_frates)
                          for c in range(len(spectrum_freqs))])
    spec_bottoms = np.array([np.interp(FR_QUANTILE_LOW, cumsum[:, c], filt_frates)
                             for c in range(len(spectrum_freqs))])
    # Interpolate from spectrum_freqs to data freqs (with extrapolation)
    mb_frate_tops = interpolate.interp1d(spectrum_freqs, spec_tops, fill_value='extrapolate')(hd.freqs)
    mb_frate_bottoms = interpolate.interp1d(spectrum_freqs, spec_bottoms, fill_value='extrapolate')(hd.freqs)
    return (hd.freqs / 1e6, mb_frate_tops, mb_frate_bottoms)
def compute_sky_fr_ranges(hd, antpair):
    '''Fringe-rate range occupied by the sky for one baseline.

    Combines the analytic sky fringe-rate center and half-width from sky_frates_single with the
    buffer derived from AUTO_FR_SPECTRUM_FILE by get_FR_buffer_from_spectra.

    Returns a tuple (freqs in MHz, upper FR bound, lower FR bound), one value per frequency,
    with fringe rates in mHz.'''
    ant_i, ant_j = antpair[0], antpair[1]
    baseline_vector = hd.antpos[ant_i] - hd.antpos[ant_j]
    array_latitude = hd.telescope.location.lat.rad

    # Analytic center and half-width of the sky's fringe rates (mHz)
    centers, half_widths = sky_frates_single(hd.freqs, baseline_vector, array_latitude)
    # Extra padding estimated from the autocorrelation fringe-rate spectrum
    padding = get_FR_buffer_from_spectra(AUTO_FR_SPECTRUM_FILE, hd.times, hd.freqs, gauss_fit_buffer_cut=1e-5)

    upper_edge = centers + half_widths + padding
    lower_edge = centers - half_widths - padding
    return (hd.freqs / 1e6, upper_edge, lower_edge)
def overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms, xtalk_frate=None):
    '''Check whether any band's main beam fringe rates overlap FR=0 ± xtalk_frate.

    Parameters
    ----------
    bands_bl : iterable of slice or None
        Frequency slices to test; None entries are skipped.
    mb_frate_tops, mb_frate_bottoms : array-like
        Per-frequency upper and lower main-beam fringe-rate bounds, indexable by each band.
    xtalk_frate : float, optional
        Half-width of the FR=0 exclusion region, in the same units as the bounds.
        Defaults to the notebook-level XTALK_FRATE setting.

    Returns
    -------
    bool
        True if, for at least one band, the [bottom, top] range is neither entirely above
        +xtalk_frate nor entirely below -xtalk_frate; False otherwise.
    '''
    if xtalk_frate is None:
        xtalk_frate = XTALK_FRATE
    for band in bands_bl:
        if band is None:
            continue
        tops, bottoms = mb_frate_tops[band], mb_frate_bottoms[band]
        # A band is clear of FR=0 only if every channel sits on the same side of the notch
        entirely_above = np.all(tops > xtalk_frate) and np.all(bottoms > xtalk_frate)
        entirely_below = np.all(tops < -xtalk_frate) and np.all(bottoms < -xtalk_frate)
        if not (entirely_above or entirely_below):
            return True
    return False
def subtract_polarized_models(data, flags, model_files, extra_flags=None):
    """
    Subtract phase-shifted models of polarized point sources from visibility data.

    For each model file, this function reads a model of a polarized source,
    computes the geometric phase shift needed to move the model from the
    phase center to the source's true sky position, and subtracts the
    phased model from the input data in-place. The model is set to zero
    (so nothing is subtracted) wherever the model, the data, or
    `extra_flags` are flagged.

    Parameters
    ----------
    data : container of visibilities
        Observed visibility data, indexable by (ant_i, ant_j, pol) and providing
        `antpairs()`, `antpos`, and `times`. In this notebook it is the data object
        returned by `HERAData.read()`. Modified in-place: the model contribution is
        subtracted from each polarization present in the model.
    flags : dict
        Boolean flag dictionary keyed by (ant_i, ant_j, pol) tuples.
        True indicates a flagged (unusable) sample. Used in conjunction
        with model flags to zero out contributions before subtraction.
    model_files : list of str
        Paths to uvh5 model files, one per polarized source, per modeling component.
        Each file must contain:
        - A single baseline's worth of model visibilities
        - `SOURCE_RA` and `SOURCE_DEC` extra keywords (degrees)
          giving the ICRS sky position of the modeled source.
    extra_flags : boolean array, optional
        Additional flags OR-ed into the combined model/data flags for every
        polarization. Must be broadcast-compatible with the per-polarization
        flag arrays. Not modified by this function.

    Returns
    -------
    None
        `data` is updated in-place.

    Raises
    ------
    KeyError
        If a model file is missing the `SOURCE_RA` or `SOURCE_DEC` extra
        keywords, or if a polarization present in the model is absent from
        the data flags dictionary.
    """
    # Use the first baseline in the dataset to define the phasing baseline vector.
    # All model files are assumed to be for the same physical baseline geometry.
    ai, aj = bl = list(data.antpairs())[0]
    blvec = data.antpos[aj] - data.antpos[ai] # East-North-Up baseline vector (meters)
    for model_file in model_files:
        hd_model = io.HERAData(model_file)
        model_data, model_flags, model_nsamples = hd_model.read()
        model_key = list(model_data.antpairs())[0] # Baseline key stored in the model file
        model_pols = model_data.pols()
        # Read the ICRS sky coordinates of the modeled source from file metadata
        right_ascension = hd_model.extra_keywords["SOURCE_RA"] # degrees
        declination = hd_model.extra_keywords["SOURCE_DEC"] # degrees
        # Compute direction cosines (l, m, n) toward the source at each
        # observation time, accounting for Earth rotation and the telescope's
        # geographic location.
        # NOTE(review): np.dot(blvec, lmn) below contracts blvec against the FIRST axis of a
        # 2D lmn, so lmn is presumably (3, n_times) rather than (n_times, 3) -- confirm
        # against polfilt.radec_to_lmn.
        lmn = polfilt.radec_to_lmn(
            right_ascension, declination, data.times, hd_model.telescope.location
        ) # shape: (n_times, 3)
        # NOTE(review): the phasor is built on the model file's frequency grid (hd_model.freqs);
        # this assumes it matches the data's frequency grid -- confirm.
        phasor = np.exp(
            2j * np.pi * np.dot(blvec, lmn)[:, np.newaxis] * hd_model.freqs / constants.c.value
        ) # shape: (n_times, n_freqs)
        for pol in model_pols:
            # Zero out samples where either the model or the data are flagged,
            # then apply the phase shift to move the model to the source position.
            # The `|` creates a new array, so the in-place `|=` below does not touch the inputs.
            flags_here = model_flags[model_key + (pol,)] | flags[bl + (pol,)]
            if extra_flags is not None:
                flags_here |= extra_flags
            bl_model = np.where(
                flags_here,
                0.0,
                model_data[model_key + (pol,)] * phasor
            ) # shape: (n_times, n_freqs)
            # Subtract the phased model from the data in-place
            data[bl + (pol,)] -= bl_model
In [9]:
def plot_fr_waterfall(snr_wf, flags_wf, taper_2d, freqs, times, title,
                      mb_frate_freqs_MHz=None, mb_frate_tops=None, mb_frate_bottoms=None,
                      sky_frate_freqs_MHz=None, sky_frate_tops=None, sky_frate_bottoms=None,
                      vmax=5):
    '''Frequency vs. fringe-rate image of |SNR|, obtained by tapering and FFTing along time.

    Takes full (already assembled) SNR and flag waterfalls plus a 2D taper that may differ from
    band to band. Optionally overlays the sky and main-beam fringe-rate bounds. The figure is
    closed before being returned, so it only renders when explicitly displayed.'''
    n_int = len(times)
    seconds = (times - times[0]) * 24 * 3600
    frates_mHz = uvtools.utils.fourier_freqs(seconds) * 1000  # mHz

    # Normalization of each frequency column, accounting for both the taper and the flags
    good = (~flags_wf).astype(float)
    column_norm = (n_int * np.mean((taper_2d * good)**2, axis=0))**.5

    # Taper, zero out anything flagged or outside the taper, and transform along time
    usable = (~flags_wf) & (taper_2d > 0)
    tapered = taper_2d * np.where(usable, snr_wf, 0)
    fr_spectrum = np.fft.fftshift(np.fft.fft(tapered, axis=0), axes=0)
    amplitude = np.abs(fr_spectrum) / column_norm[np.newaxis, :]

    fig, ax = plt.subplots(figsize=(14, 8), dpi=200)
    image = ax.imshow(amplitude, aspect='auto', interpolation='none',
                      extent=[freqs[0] / 1e6, freqs[-1] / 1e6, frates_mHz[-1], frates_mHz[0]],
                      vmin=0, vmax=vmax, cmap='plasma')
    fig.colorbar(image, ax=ax, extend='max', label='|pI SNR|')
    ax.set_xlabel('Frequency (MHz)')
    ax.set_ylabel('Fringe Rate (mHz)')
    ax.set_title(title)

    have_sky = sky_frate_freqs_MHz is not None
    have_mb = mb_frate_freqs_MHz is not None

    if have_sky:
        ax.plot(sky_frate_freqs_MHz, sky_frate_tops, 'w:', lw=1, label='Sky FRs')
        ax.plot(sky_frate_freqs_MHz, sky_frate_bottoms, 'w:', lw=1)
        # Symmetric y-range: 25% beyond the largest sky fringe rate
        fr_limit = np.max([np.abs(sky_frate_tops), np.abs(sky_frate_bottoms)]) * 1.25
        ax.set_ylim([-fr_limit, fr_limit])
    else:
        ax.set_ylim([-5, 5])

    if have_mb:
        ax.plot(mb_frate_freqs_MHz, mb_frate_tops, 'w--', lw=1, label='Main Beam FRs')
        ax.plot(mb_frate_freqs_MHz, mb_frate_bottoms, 'w--', lw=1)

    if have_sky or have_mb:
        ax.legend()

    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_time_freq_waterfall(snr_wf, flags_wf, freqs, times, lsts, title, vmax=5):
    '''Frequency vs. time image of |SNR|, with JD on the left axis and LST on the right.

    Takes full (already assembled) waterfalls; flagged pixels are drawn as NaN. The figure is
    closed before being returned, so it only renders when explicitly displayed.'''
    jd_int = int(times[0])
    masked_amplitude = np.where(flags_wf, np.nan, np.abs(snr_wf))

    fig, ax = plt.subplots(figsize=(14, 8), dpi=200)
    image = ax.imshow(masked_amplitude, aspect='auto', interpolation='none',
                      extent=[freqs[0] / 1e6, freqs[-1] / 1e6,
                              times[-1] - jd_int, times[0] - jd_int],
                      vmin=0, vmax=vmax, cmap='plasma')
    plt.colorbar(image, extend='max', label='|pI SNR|', ax=ax)
    ax.set_xlabel('Frequency (MHz)')
    ax.set_ylabel(f'JD - {jd_int}')
    ax.set_title(title)

    # Right-hand LST axis. LSTs larger than the final one belong to before the 24h wrap,
    # so shift them down by a day to keep the axis monotonic; tick labels are shown mod 24.
    lst_hours = lsts * 12 / np.pi  # radians to hours
    before_wrap = lst_hours > lst_hours[-1]
    lst_hours[before_wrap] -= 24

    def format_mod24(value, _):
        return f"{value % 24:.1f}"

    lst_ax = ax.twinx()
    lst_ax.set_ylim(lst_hours[-1], lst_hours[0])
    lst_ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(format_mod24))
    lst_ax.set_ylabel('LST (hours)')

    plt.tight_layout()
    plt.close(fig)
    return fig
def plot_snr_histograms(snr_wf, flags_wf, title):
    '''Histogram of unflagged |SNR| against the Rayleigh curve expected for pure noise.

    Takes a full (already assembled) waterfall and its flags. Only finite, strictly positive
    amplitudes are histogrammed. The figure is closed before being returned, so it only renders
    when explicitly displayed.'''
    fig, ax = plt.subplots(figsize=(12, 5))
    edges = np.arange(0, 10, .01)

    amplitudes = np.abs(snr_wf[~flags_wf])
    keep = np.isfinite(amplitudes) & (amplitudes > 0)
    amplitudes = amplitudes[keep]

    densities, _, _ = ax.hist(amplitudes, bins=edges, density=True, label='Real-space |pI SNR|')
    ax.plot(edges, 2 * edges * np.exp(-edges**2), 'k--', label='Rayleigh Distribution (Noise-Only)')
    ax.set_yscale('log')

    # Frame the y-axis around the populated bins only (empty bins would break the log scale)
    populated = densities[densities > 0]
    if len(populated) > 0:
        ax.set_ylim(np.min(populated) / 2, np.max(populated) * 2)

    ax.legend()
    ax.set_ylabel('Density')
    ax.set_xlabel('|pI SNR|')
    ax.set_xlim([-.5, 10])
    ax.set_title(title)
    fig.tight_layout()
    plt.close(fig)
    return fig
Compute pI SNR, Looping Over Baselines¶
In [10]:
# Per-baseline figure accumulators; the figure cells below iterate over these lists.
delay_fr_figs = []
dly_waterfall_figs = []
dly_histogram_figs = []
waterfall_figs = []
histogram_figs = []
# Main loop: one iteration per single-baseline file corner-turned from RED_AVG_FILE.
for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:
    # Nothing would be written, so skip all the work
    if not SAVE_DLY_SNR and not SAVE_FRF_SNR:
        continue
    # Load data
    single_bl_file = single_bl_file.replace('.uvh5', '.inpainted.uvh5')
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read(polarizations=['ee', 'nn'])
    dt = np.median(np.diff(hd.times)) * 24 * 3600
    df = np.median(np.diff(hd.freqs))
    antpair = data.antpairs().pop()
    if antpair[0] == antpair[1]:
        print('\tThis baseline is an autocorrelation. Skipping...')
        continue
    # Require at least one pol whose median nsamples exceeds MIN_SAMP_FRAC of the autos' median nsamples
    med_auto_nsamples = {bl[2]: np.median(n) for bl, n in auto_nsamples.items()}
    if not any([np.median(nsamples[bl]) > MIN_SAMP_FRAC * med_auto_nsamples[bl[2]] for bl in nsamples]):
        print('\tNo polarization has enough nsamples to be worth filtering. Skipping...')
        continue
    # Combine flags across pols, including prior flags if enabled
    flags_here = flags[antpair + ('ee',)] | flags[antpair + ('nn',)]
    if APPLY_PRIOR_FLAGS:
        flags_here |= prior_flags
    # OR in per-baseline where_inpainted flags to make sure they propagate into full_day_rfi_round_4 and full_day_rfi_round_5
    if APPLY_WHERE_INPAINTED_FLAGS:
        # Placeholder only; overwritten below once the where_inpainted file has been read
        wip_flags = np.zeros_like(flags_here)
        where_inpainted_file = single_bl_file.replace('.inpainted.uvh5', WHERE_INPAINTED_SUFFIX)
        if not os.path.exists(where_inpainted_file):
            raise FileNotFoundError(
                f'APPLY_WHERE_INPAINTED_FLAGS is True but no file found at {where_inpainted_file}')
        uvf_wi = UVFlag(where_inpainted_file)
        # uvf_wi.flag_array has shape (Ntimes, Nfreqs, Npols); collapse pol axis with np.any
        # to match the ee|nn combine semantics used for flags_here.
        wip_flags = np.any(uvf_wi.flag_array, axis=-1)
        if wip_flags.shape != flags_here.shape:
            raise ValueError(
                f'where_inpainted flags shape {wip_flags.shape} does not match '
                f'flags_here shape {flags_here.shape} for {where_inpainted_file}')
        flags_here |= wip_flags
        del uvf_wi
    # Get tslice and bands split around FM
    tslices_bl, bands_bl = flag_utils.get_minimal_slices(flags_here,
        freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
    if SUBTRACT_POLARIZED_SOURCE:
        print ("\tSubtracting off a polarized source model from the visibilities...")
        # Modifies data in place; flags_here is passed as extra flags so the model is zeroed there
        subtract_polarized_models(data, flags, model_files, extra_flags=flags_here)
    # Compute FR ranges for this baseline
    print('\tComputing fringe rate ranges...')
    (mb_frate_freqs_MHz, mb_frate_tops, mb_frate_bottoms) = compute_mb_fr_ranges(hd, antpair)
    (sky_frate_freqs_MHz, sky_frate_tops, sky_frate_bottoms) = compute_sky_fr_ranges(hd, antpair)
    # Decide whether to process baselines whose main beam overlaps FR = 0 ± XTALK_FRATE.
    # These are skipped only when SKIP_FR0_OVERLAP_BASELINES is TRUE (it defaults to FALSE),
    # since they often contain residual crosstalk-like structure at FR=0 that's probably not RFI.
    # When they are processed, the FRF step applies a per-channel FR=0 notch (and a leverage
    # correction that accounts for it).
    fr0_overlap = overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms)
    if fr0_overlap and SKIP_FR0_OVERLAP_BASELINES:
        print(f'\tThis baseline overlaps FR = 0 ± {XTALK_FRATE} mHz. Skipping. '
              f'Set SKIP_FR0_OVERLAP_BASELINES=FALSE to process these with a per-channel FR=0 notch in the FRF step.')
        continue
    if fr0_overlap and SAVE_FRF_SNR:
        print(f'\tThis baseline overlaps FR = 0 ± {XTALK_FRATE} mHz. The FRF will apply a per-channel FR=0 notch.')
    elif fr0_overlap:
        print(f'\tWARNING: this baseline overlaps FR = 0 ± {XTALK_FRATE} mHz, but SAVE_FRF_SNR is False -- '
              'no FR=0 notch will be applied (as it is part of the FRF step).')
    # Process each band
    # Full-waterfall outputs: everything starts flagged / NaN and is filled in band by band
    filt_flags_full = np.ones((len(hd.times), len(hd.freqs)), dtype=bool)
    newly_flagged = np.zeros((len(hd.times), len(hd.freqs)), dtype=bool)
    dly_filt_SNR_full = np.full((len(hd.times), len(hd.freqs)), np.nan, dtype=complex)
    frf_SNR_full = np.full((len(hd.times), len(hd.freqs)), np.nan, dtype=complex)
    taper_2d = np.zeros((len(hd.times), len(hd.freqs)))
    for tslice, band in zip(tslices_bl, bands_bl):
        if (band is None) or np.all(flags_here[tslice, band]):
            continue
        # Extract per-pol data for this band
        d_ee = data[antpair + ('ee',)][tslice, band]
        d_nn = data[antpair + ('nn',)][tslice, band]
        f_ee = flags[antpair + ('ee',)][tslice, band]
        f_nn = flags[antpair + ('nn',)][tslice, band]
        n_ee = nsamples[antpair + ('ee',)][tslice, band]
        n_nn = nsamples[antpair + ('nn',)][tslice, band]
        a_ee = np.abs(autos[autos.antpairs().pop() + ('ee',)][tslice, band])
        a_nn = np.abs(autos[autos.antpairs().pop() + ('nn',)][tslice, band])
        # Compute variance from autos
        # NOTE(review): the ee and nn variances are summed with no pol_convention-dependent
        # prefactor -- confirm this matches how _combine_pol_arrays normalizes pI below.
        var_pI = a_ee**2 / (dt * df) / n_ee + a_nn**2 / (dt * df) / n_nn
        # Form pseudo-Stokes pI
        d_pI, f_pI, n_pI = hp.pstokes._combine_pol_arrays(
            'ee', 'nn', 'pI', pol_convention=hd.pol_convention,
            data_list=[d_ee, d_nn], flags_list=[f_ee, f_nn],
            nsamples_list=[n_ee, n_nn],
            x_orientation=hd.telescope.get_x_orientation_from_feeds())
        if APPLY_PRIOR_FLAGS:
            f_pI |= prior_flags[tslice, band]
        if APPLY_WHERE_INPAINTED_FLAGS:
            f_pI |= wip_flags[tslice, band]
        d_pI[f_pI] = 0
        # Compute SNR
        SNR = d_pI / var_pI**.5
        # Delay filter using dpss_matrix (pinv-based) for stability across large frequency gaps.
        # dpss_leastsq and dpss_solve are faster but can produce unstable results when X^T W X
        # is ill-conditioned due to large contiguous gaps in the flagging.
        print(f'\tDelay filtering band {band}...')
        wgts = np.where(f_pI, 0, 1).astype(float)
        filter_kwargs = dict(filter_centers=[0], filter_half_widths=[FILTER_DELAY * 1e-9],
                             eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
                             max_contiguous_edge_flags=len(hd.freqs), ridge_alpha=RIDGE_ALPHA)
        result, _, info = fourier_filter(hd.freqs[band], SNR, wgts=wgts,
                                         mode='dpss_matrix', **filter_kwargs)
        # Residual after removing the fitted low-delay component
        dly_filt_SNR = SNR - result
        # Per-integration weighted leverage correction
        X = dpss_operator(hd.freqs[band], [0],
                          filter_half_widths=[FILTER_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
        # Integrations sharing a flag pattern share a correction, so cache on the pattern's bytes
        correction_cache = {}
        for i in range(SNR.shape[0]):
            w = wgts[i]
            if not np.any(w):
                continue
            cache_key = f_pI[i].tobytes()
            if cache_key not in correction_cache:
                XtWX = (X.T * w) @ X
                # Ridge term proportional to the diagonal
                XtWX += np.diag(np.diag(XtWX)) * RIDGE_ALPHA
                if np.all(np.isclose(XtWX.imag, 0)):
                    XtWX = np.real(XtWX)
                try:
                    P = np.linalg.pinv(XtWX, hermitian=True) @ (X.T * w)
                    # Diagonal of the hat matrix X @ P, i.e. the per-channel leverage
                    lev = np.real(np.sum(X * P.T, axis=1))
                    lev = np.clip(lev, 0.0, LEVERAGE_CAP)
                    correction = (1 - lev)**.5
                    # Channels with leverage exactly 0 (after clipping) get NaN and are flagged below
                    correction_cache[cache_key] = np.where(np.isfinite(correction) & (lev > 0) & (lev < 1), correction, np.nan)
                except np.linalg.LinAlgError:
                    correction_cache[cache_key] = np.full(X.shape[0], np.nan)
            dly_filt_SNR[i] /= correction_cache[cache_key]
        # update larger arrays and account for non-finite SNR (degenerate leverage)
        newly_flagged[tslice, band] |= ~np.isfinite(dly_filt_SNR) & ~f_pI
        f_pI |= newly_flagged[tslice, band]
        dly_filt_SNR[newly_flagged[tslice, band]] = 0
        filt_flags_full[tslice, band] = f_pI
        dly_filt_SNR_full[tslice, band] = dly_filt_SNR
        if SAVE_FRF_SNR:
            # The FR=0 notch and main-beam FRF act on a copy of the delay-filtered
            # data; dly_filt_SNR (and dly_filt_SNR_full above) are left untouched.
            times_in_seconds = (hd.times[tslice] - hd.times[tslice][0]) * 24 * 3600
            frf_SNR = copy.deepcopy(dly_filt_SNR)
            # ---- FR=0 notch on frf_SNR, per-channel (always, when SAVE_FRF_SNR) ----
            # The cache stores (correction_per_t, P_fr0); P_fr0 is reused below to
            # compute the leverage of the composed FR=0 + main-beam operator.
            print(f'\tFR=0 notch filtering band {band}...')
            Xt_fr0 = dpss_operator(hd.times[tslice] * 24 * 3600, filter_centers=[0],
                                   filter_half_widths=[XTALK_FRATE / 1000], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
            fr0_correction_cache = {}
            # fr_low / fr_high are not used in this loop; only the channel index matters here
            for chan, (fr_low, fr_high) in enumerate(zip(mb_frate_bottoms[band], mb_frate_tops[band])):
                w_fr0 = np.where(f_pI[:, chan], 0, 1).astype(float)
                if not np.any(w_fr0):
                    continue
                cache_key_fr0 = f_pI[:, chan].tobytes()
                if cache_key_fr0 not in fr0_correction_cache:
                    XtWXt_fr0 = np.dot(Xt_fr0.conj().T * w_fr0, Xt_fr0)
                    XtWXt_fr0 += np.diag(np.diag(XtWXt_fr0)) * RIDGE_ALPHA
                    try:
                        P_fr0 = np.linalg.pinv(XtWXt_fr0, hermitian=True) @ (Xt_fr0.conj().T * w_fr0)
                        lev_fr0 = np.real(np.sum(Xt_fr0 * P_fr0.T, axis=1))
                        lev_fr0 = np.clip(lev_fr0, 0.0, LEVERAGE_CAP)
                        correction_fr0 = (1 - lev_fr0)**.5
                        correction_fr0 = np.where(
                            np.isfinite(correction_fr0) & (lev_fr0 > 0) & (lev_fr0 < 1),
                            correction_fr0, np.nan)
                        fr0_correction_cache[cache_key_fr0] = (correction_fr0, P_fr0)
                    except np.linalg.LinAlgError:
                        fr0_correction_cache[cache_key_fr0] = (np.full(Xt_fr0.shape[0], np.nan), None)
                result_fr0, _, _ = fourier_filter(times_in_seconds,
                                                  frf_SNR[:, chan:chan+1],
                                                  wgts=np.where(f_pI[:, chan:chan+1], 0, 1),
                                                  filter_centers=[0], filter_half_widths=[XTALK_FRATE / 1000],
                                                  mode='dpss_leastsq', eigenval_cutoff=[EIGENVAL_CUTOFF],
                                                  suppression_factors=[EIGENVAL_CUTOFF],
                                                  max_contiguous_edge_flags=len(hd.times[tslice]),
                                                  filter_dims=0, ridge_alpha=RIDGE_ALPHA)
                frf_SNR[:, chan:chan+1] -= result_fr0
                frf_SNR[:, chan] /= fr0_correction_cache[cache_key_fr0][0]
            # The per-time divide by sqrt(1 - lev_fr0) above can introduce NaN at
            # degenerate-leverage times. Zero those before the main FR fit so they
            # don't poison fourier_filter's masked least-squares (NaN at mask=True
            # positions propagates through Xy = amat[mask].T @ y[mask]).
            newly_flagged[tslice, band] |= ~np.isfinite(frf_SNR) & ~f_pI
            f_pI |= newly_flagged[tslice, band]
            frf_SNR[newly_flagged[tslice, band]] = 0
            filt_flags_full[tslice, band] = f_pI
            # ---- Per-channel FR filter at the main-beam center ----
            print(f'\tFR filtering band {band}...')
            for chan, (fr_low, fr_high) in enumerate(zip(mb_frate_bottoms[band], mb_frate_tops[band])):
                fr_center = (fr_low + fr_high) / 2 / 1000 # mHz -> Hz
                fr_halfwidth = (fr_high - fr_low) / 2 / 1000 # mHz -> Hz
                # filter directly at the main-beam fringe rate center
                # NOTE(review): unlike the FR=0 notch fit above, this call does not pass ridge_alpha,
                # while the leverage computed below does add RIDGE_ALPHA -- confirm this is intended.
                result, _, _ = fourier_filter(times_in_seconds, frf_SNR[:, chan:chan+1],
                                              wgts=np.where(f_pI[:, chan:chan+1], 0, 1), filter_centers=[fr_center],
                                              filter_half_widths=[fr_halfwidth], mode='dpss_leastsq',
                                              eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
                                              max_contiguous_edge_flags=len(data.times), filter_dims=0)
                Xt = dpss_operator(hd.times[tslice] * 24 * 3600, filter_centers=[fr_center],
                                   filter_half_widths=[fr_halfwidth], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
                W = np.where(f_pI[:, chan], 0, 1).astype(float)
                XtWXt = np.dot(Xt.conj().T * W, Xt)
                XtWXt += np.diag(np.diag(XtWXt)) * RIDGE_ALPHA
                try:
                    P = np.linalg.pinv(XtWXt, hermitian=True) @ (Xt.conj().T * W)
                    # Per-time noise variance after the composed operator
                    # M = H_main @ D_fr0 @ (I - H_fr0)
                    # acts on raw unit-variance noise. Factored form:
                    # Q = P*D_fr0 - (P*D_fr0 @ Xt_fr0) @ P_fr0
                    # var_t = (Xt @ S @ Xt^H)_tt, S = Q @ Q^H
                    # avoids ever materialising the n_t x n_t hat matrices.
                    # NOTE(review): the cache was keyed on f_pI as it was DURING the notch loop; f_pI may
                    # have gained flags in the post-notch step above, in which case this lookup misses
                    # and the FR=0 term is dropped (Q = P) -- confirm this is intended.
                    cache_entry = fr0_correction_cache.get(f_pI[:, chan].tobytes())
                    if cache_entry is not None and cache_entry[1] is not None:
                        correction_fr0, P_fr0 = cache_entry
                        D_fr0 = np.where(np.isfinite(correction_fr0), 1.0 / correction_fr0, 0.0)
                        PD = P * D_fr0[None, :]
                        Q = PD - (PD @ Xt_fr0) @ P_fr0
                    else:
                        Q = P
                    S = Q @ Q.conj().T
                    var_t = np.real(np.einsum('tb,bc,tc->t', Xt, S, Xt.conj()))
                    lev_t_correction = np.where(var_t > 0, np.sqrt(var_t), np.nan)
                except np.linalg.LinAlgError:
                    lev_t_correction = np.full(Xt.shape[0], np.nan)
                # Here the fitted main-beam component (not the residual) is kept, normalized per time
                frf_SNR[:, chan:chan+1] = np.where(f_pI[:, chan:chan+1], 0, result / lev_t_correction[:, None])
            # account for non-finite SNR again
            frf_SNR_full[tslice, band] = frf_SNR
            newly_flagged[tslice, band] |= ~np.isfinite(frf_SNR) & ~f_pI
        # build 2D taper for plotting
        band_ntimes = tslice.stop - tslice.start
        taper_2d[tslice, band] = blackmanharris(band_ntimes)[:, np.newaxis]
    if np.any(newly_flagged):
        print(f'\tFlagging {np.sum(newly_flagged)} waterfall pixels due to NaNs and infs, typically due to degenerate leverage.')
        filt_flags_full |= newly_flagged
    if np.all(filt_flags_full):
        print(f'\t{antpair} is entirely flagged.')
        continue
    # Now produce figures to display later
    delay_fr_figs.append(plot_fr_waterfall(
        dly_filt_SNR_full, filt_flags_full, taper_2d, hd.freqs, hd.times,
        f'{antpair} Delay-Filtered pI',
        mb_frate_freqs_MHz=mb_frate_freqs_MHz, mb_frate_tops=mb_frate_tops, mb_frate_bottoms=mb_frate_bottoms,
        sky_frate_freqs_MHz=sky_frate_freqs_MHz, sky_frate_tops=sky_frate_tops, sky_frate_bottoms=sky_frate_bottoms))
    dly_waterfall_figs.append(plot_time_freq_waterfall(
        dly_filt_SNR_full, filt_flags_full, hd.freqs, hd.times, hd.lsts,
        f'{antpair} Delay-Filtered pI'))
    dly_histogram_figs.append(plot_snr_histograms(
        dly_filt_SNR_full, filt_flags_full, f'{antpair} Delay-Filtered pI'))
    if SAVE_FRF_SNR:
        waterfall_figs.append(plot_time_freq_waterfall(
            frf_SNR_full, filt_flags_full, hd.freqs, hd.times, hd.lsts,
            f'{antpair} Delay+FR-Filtered pI (+ FR=0 notch)'))
        histogram_figs.append(plot_snr_histograms(
            frf_SNR_full, filt_flags_full, f'{antpair} Delay+FR-Filtered pI (+ FR=0 notch)'))
    # Write SNR outputs
    if SAVE_DLY_SNR or SAVE_FRF_SNR:
        # Reuse the 'ee' slot of the loaded object as the container for pI: keep only 'ee',
        # then relabel that polarization as pI.
        hd.select(polarizations=['ee'])
        hd.polarization_array[0] = utils.polstr2num('pI')
        for bl in list(data.keys()):
            if bl[2] != 'ee':
                del data[bl]
                del flags[bl]
        for to_save, snr_full, suffix, label in [(SAVE_DLY_SNR, dly_filt_SNR_full, DLY_SNR_SUFFIX, 'delay-filtered'),
                                                 (SAVE_FRF_SNR, frf_SNR_full, FRF_SNR_SUFFIX, 'delay-filtered and FRFed')]:
            if not to_save:
                continue
            # Non-finite SNRs (never-filled or degenerate pixels) are written as 0; they are flagged
            data[antpair + ('ee',)] = np.where(np.isfinite(snr_full), snr_full, 0)
            flags[antpair + ('ee',)] = filt_flags_full
            hd.update(data=data, flags=flags)
            outfile = single_bl_file.replace('.uvh5', suffix)
            print(f'\tWriting {label} results to {outfile}')
            hd.write_uvh5(outfile, clobber=True)
Now loading /lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459920/single_baseline_files/zen.2459920.baseline.0_4.sum.smooth_calibrated.red_avg.inpainted.uvh5
Subtracting off a polarized source model from the visibilities...
Computing fringe rate ranges...
Delay filtering band slice(np.int64(0), np.int64(332), None)...
invalid value encountered in divide
Delay filtering band slice(np.int64(502), np.int64(1500), None)...
invalid value encountered in divide
Writing delay-filtered results to /lustre/aoc/projects/hera/h6c-analysis/IDR3_2/2459920/single_baseline_files/zen.2459920.baseline.0_4.sum.smooth_calibrated.red_avg.inpainted.pI_DLYFILT_SNR.uvh5
Figure 1: Delay-Filtered pI SNR in Fringe Rate Space¶
In [11]:
# Render the delay-filtered fringe-rate-space figures collected in the baseline loop.
for fig in delay_fr_figs:
    display(fig)
Figure 2: Delay-Filtered pI SNR Waterfall¶
In [12]:
# Render the delay-filtered time-frequency waterfalls collected in the baseline loop.
for fig in dly_waterfall_figs:
    display(fig)
Figure 3: Delay-Filtered pI SNR Histogram¶
In [13]:
# Render the delay-filtered |SNR| histograms collected in the baseline loop.
for fig in dly_histogram_figs:
    display(fig)
Figure 4: Delay+FR Filtered pI SNR Waterfall¶
In [14]:
# Render the delay+FR-filtered waterfalls; this list is only filled when SAVE_FRF_SNR is True.
for fig in waterfall_figs:
    display(fig)
Figure 5: Delay+FR Filtered pI SNR Histogram¶
In [15]:
# Render the delay+FR-filtered |SNR| histograms; this list is only filled when SAVE_FRF_SNR is True.
for fig in histogram_figs:
    display(fig)
Metadata¶
In [16]:
# Record the versions of the key packages used for this run.
# importlib is used instead of exec'ing an import statement, so no code is built from strings
# and no stray __version__ name is left in the notebook namespace.
import importlib
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'hera_pspec', 'pyuvdata', 'numpy']:
    repo_version = importlib.import_module(repo).__version__
    print(f'{repo}: {repo_version}')
hera_cal: 3.7.8.dev62+gf11268d54 hera_qm: 2.2.1.dev12+g95ecc30f0 hera_filters: 0.1.9.dev3+ged4deb46a
hera_notebook_templates: 0.0.1.dev1457+gb078d937f hera_pspec: 0.4.3.dev99+g0e4b0e22b pyuvdata: 3.2.5 numpy: 2.3.5
In [17]:
# Total wall-clock time since tstart was recorded in the first cell.
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')
Finished execution in 3.64 minutes.