Single Baseline LST-Stacker and Re-Inpainter¶
by Josh Dillon and Tyler Cox, last updated April 24, 2026
This notebook performs LST-stacking (a.k.a. LST-binning) of whole-JD, single baseline, all pol files. Most parameters are controlled by a toml config file, such as this one. In addition to single-baseline files, this notebook also requires UVFlag-compatible where_inpainted files which tell us where inpainting was previously done.
In addition to LST-stacking, which includes rephasing to a common grid, this notebook also performs re-inpainting. Data that are outliers among the other nights (in terms of a high modified $z$-score) and had previously been inpainted are now re-inpainted (on a whole band, single integration basis) using information from other nights, as well as feathering to prevent discontinuities. Next, a bit of "ex-painting" is done to ensure that all nights span the same frequency range (or are completely flagged). This is again informed by what's going on on other nights. Despite the potentially large amount of inpainting, inpainted data are considered to have Nsamples=0 in the final LST-stacked data products.
Finally, this notebook performs an optional per-night FR=0 filter, under the theory that per-night FR=0 systematics might vary from night to night in a way that can be mitigated with per-night filtering. This data product is saved separately.
Here's a set of links to skip to particular figures and tables:
• Figure 1: East-Polarized LST-Stacked Amplitude, Phase, and Nsamples after Re-Inpainting¶
• Figure 2: North-Polarized LST-Stacked Amplitude, Phase, and Nsamples after Re-Inpainting¶
• Figure 3: Modified z-Score Across Nights, Before and After Re-Inpainting¶
• Figure 4: Night with Highest Modified z-Score After Re-Inpainting¶
• Figure 5: Per-Night $z^2$-Score Histograms for Not-Inpainted Data¶
• Figure 6: Per-Night $z^2$-Score Histograms for Inpainted Data¶
• Figure 7: Time-Averaged $z^2$-Score vs Frequency¶
• Figure 8: Frequency-Averaged $z^2$-Score vs LST¶
import time
tstart = time.time()
!hostname
!date
herapost004
Fri May 15 22:57:42 MDT 2026
import os
import resource
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
# Thread configuration: set BLAS thread count BEFORE importing numpy
# to avoid oversubscription when using ThreadPoolExecutor.
# NUM_THREADS is clamped to >= 1 so that NUM_THREADS=0 (or a negative value) in the
# environment can neither divide by zero below nor yield an invalid worker count later.
NUM_THREADS = max(1, int(os.environ.get('NUM_THREADS', min(os.cpu_count() or 1, 16))))
BLAS_THREADS = str(max(1, (os.cpu_count() or 1) // NUM_THREADS))
# setdefault: thread counts the user has already exported take precedence
os.environ.setdefault('OMP_NUM_THREADS', BLAS_THREADS)
os.environ.setdefault('MKL_NUM_THREADS', BLAS_THREADS)
os.environ.setdefault('OPENBLAS_NUM_THREADS', BLAS_THREADS)
import h5py
import hdf5plugin # REQUIRED to have the compression plugins available
import numpy as np
import scipy
import copy
import toml
import threading
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, wait
import multiprocessing
import itertools
from astropy import units
from functools import reduce
import matplotlib.pyplot as plt
import matplotlib
import warnings
import hashlib
import tempfile
from collections import defaultdict
from pyuvdata import UVData
from hera_cal import lst_stack, utils, io, flag_utils
from hera_cal.lst_stack import calibration
from hera_cal.frf import sky_frates, get_FR_buffer_from_spectra
from hera_cal.lst_stack.binning import SingleBaselineStacker
from hera_qm.time_series_metrics import true_stretches
from hera_filters.dspec import fourier_filter, dpss_operator, sparse_linear_fit_2D
from IPython.display import display, HTML
%matplotlib inline
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore') # get rid of red warnings
%config InlineBackend.figure_format = 'retina'
print(f'NUM_THREADS = {NUM_THREADS}, BLAS_THREADS = {BLAS_THREADS}')
NUM_THREADS = 16, BLAS_THREADS = 1
# this enables better memory management on linux
import gc
import ctypes
def gc_and_malloc_trim():
    """Run a full garbage collection, then ask libc to trim its heap.

    The ``malloc_trim(0)`` call is best-effort: if ``libc.so.6`` cannot be
    loaded, or the loaded library does not export ``malloc_trim``, the trim
    step is silently skipped and only the garbage collection takes effect.
    """
    gc.collect()
    try:
        ctypes.CDLL('libc.so.6').malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: libc.so.6 is not loadable on this platform.
        # AttributeError: the library loaded but has no malloc_trim symbol.
        pass
def print_peak_rss(label):
    """Print current and peak resident memory (GB) plus minutes elapsed since ``tstart``.

    Parameters
    ----------
    label : str
        Text identifying the point in the notebook at which memory is reported.

    Notes
    -----
    Current RSS is read from the ``VmRSS:`` line of ``/proc/self/status``. If
    that file cannot be read, or contains no such line, the peak value is
    reported as the current value.
    """
    import sys  # local import: only needed to pick the ru_maxrss unit

    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS
    maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    peak_gb = maxrss / (1024**3 if sys.platform == 'darwin' else 1024**2)

    # default to the peak so cur_gb is always bound, whatever happens below
    cur_gb = peak_gb
    try:
        with open('/proc/self/status') as _f:
            for line in _f:
                if line.startswith('VmRSS:'):
                    cur_gb = int(line.split()[1]) / 1024**2  # value is in kB
                    break
    except OSError:
        pass  # no readable /proc (e.g. macOS): keep the peak-only fallback

    elapsed_min = (time.time() - tstart) / 60.0
    print(f"[mem] {label}: current RSS = {cur_gb:.2f} GB, peak RSS = {peak_gb:.2f} GB, elapsed = {elapsed_min:.2f} min")
Parse Options¶
toml_file = os.environ.get('TOML_FILE', '/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/single_bl_lst_stack.toml')
print(f'toml_file = "{toml_file}"')
baseline_string = os.environ.get('BASELINE_STRING', None)
print(f'baseline_string = "{baseline_string}"')
PRELIMINARY = (os.environ.get('PRELIMINARY', "FALSE").upper() == "TRUE")
print(f'PRELIMINARY = {PRELIMINARY}')
toml_file = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml" baseline_string = "124_0" PRELIMINARY = False
# get options from toml file, print them out, and update globals
toml_options = toml.load(toml_file)
print(f"Now setting the following global variables from {toml_file}:\n")
globals().update({'lst_branch_cut': toml_options['FILE_CFG']['lst_branch_cut']})
print(f"lst_branch_cut = {lst_branch_cut}")
globals().update({'where_inpainted_file_rules': toml_options['FILE_CFG']['where_inpainted_file_rules']})
print(f"where_inpainted_file_rules = {where_inpainted_file_rules}")
if PRELIMINARY:
# this is used for an initial stacking of a handful of baselines, which are then used for LSTCal
toml_options['LST_STACK_OPTS']['FNAME_FORMAT'] = toml_options['LST_STACK_OPTS']['FNAME_FORMAT'].replace('.sum.uvh5', '.preliminary.sum.uvh5')
for key, val in toml_options['LSTCAL_OPTS'].items():
if isinstance(val, str):
print(f'{key} = "{val}"')
else:
print(f'{key} = {val}')
globals().update(toml_options['LSTCAL_OPTS'])
for key, val in toml_options['LST_STACK_OPTS'].items():
if isinstance(val, str):
print(f'{key} = "{val}"')
else:
print(f'{key} = {val}')
globals().update(toml_options['LST_STACK_OPTS'])
Now setting the following global variables from /lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml:
lst_branch_cut = 5.2
where_inpainted_file_rules = [['.reinpainted.uvh5', '.where_reinpainted.h5']]
NBLS_FOR_LSTCAL = 30
RUN_AMPLITUDE_CAL = True
RUN_TIP_TILT_PHASE_CAL = True
RUN_CROSS_POL_PHASE_CAL = True
INCLUDE_AUTOS = False
FREQ_SMOOTHING_SCALE = 30.0
TIME_SMOOTHING_SCALE = 2500
BLACKLIST_TIMESCALE_FACTOR = 10
BLACKLIST_RELATIVE_ERROR_THRESH = 0.1
BLACKLIST_NITER = 5
WHERE_INPAINTED_WGTS = 0.0001
LSTCAL_FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.{night}.lstcal.hdf5"
MAX_BL_LENGTH = 1e+100
MIN_AVG_REDUNDANCY = 0.0
OUTDIR = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights"
FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.LST.baseline.{bl_str}.sum.uvh5"
USE_LSTCAL_GAINS = True
FM_LOW_FREQ = 87.5
FM_HIGH_FREQ = 108.0
EIGENVAL_CUTOFF = 1e-12
INPAINT_DELAY = 1000
AUTO_INPAINT_DELAY = 100
REINPAINT_AUTOS = False
INPAINT_WIDTH_FACTOR = 0.5
INPAINT_ZERO_DIST_WEIGHT = 0.01
MIN_NIGHTS_PER_BIN = 3
MAX_CHANNEL_NIGHTLY_FLAG_FRAC = 0.25
MOD_Z_TO_REINPAINT = 5
AUTO_FR_SPECTRUM_FILE = "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5"
GAUSS_FIT_BUFFER_CUT = 1e-05
CG_TOL = 1e-06
CG_ITER_LIM = 500
FR0_FILTER = True
FR0_HALFWIDTH = 0.01
FR0_FILT_AUTOS = False
# figure out outfiles
OUTFILE = os.path.join(OUTDIR, FNAME_FORMAT.replace('{bl_str}', baseline_string))
print(f'OUTFILE = "{OUTFILE}"')
if FR0_FILTER:
FR0_FILT_OUTFILE = OUTFILE.replace('.uvh5', '.FR0filt.uvh5')
print(f'FR0_FILT_OUTFILE = "{FR0_FILT_OUTFILE}"')
# if necessary, create the output directory
if not os.path.exists(os.path.dirname(OUTFILE)):
os.makedirs(os.path.dirname(OUTFILE), exist_ok=True)
OUTFILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights/single_bl_reinpaint_1000ns/zen.LST.baseline.124_0.sum.uvh5" FR0_FILT_OUTFILE = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights/single_bl_reinpaint_1000ns/zen.LST.baseline.124_0.sum.FR0filt.uvh5"
Load Data¶
# build configurator from toml file
configurator = lst_stack.config.LSTBinConfiguratorSingleBaseline.from_toml(toml_file)
if not PRELIMINARY and USE_LSTCAL_GAINS:
configurator.build_visfile_to_calfile_map(
os.path.join(OUTDIR, LSTCAL_FNAME_FORMAT)
)
cal_file_loader = calibration.load_single_baseline_lstcal_gains
else:
cal_file_loader = None
auto_baseline_string = [s for s in configurator.bl_to_file_map if (p := s.split('_'))[0] == p[1]][0]
# get key data properties from a single file
hd = io.HERAData(configurator.bl_to_file_map[auto_baseline_string][0])
pol_convention = hd.pol_convention
df = np.median(np.diff(hd.freqs))
dlst = np.median(np.diff(hd.lsts))
lst_grid = lst_stack.config.make_lst_grid(dlst, begin_lst=0, lst_width=(2 * np.pi))
lst_bin_edges = np.concatenate([lst_grid - dlst / 2, (lst_grid[-1] + dlst / 2)[None]])
low_band = slice(0, np.searchsorted(hd.freqs, FM_LOW_FREQ * 1e6))
high_band = slice(np.searchsorted(hd.freqs, FM_HIGH_FREQ * 1e6), len(hd.freqs))
# load or cache averaged autocorrelations for weights
# The cache key is the MD5 of the serialized TOML options plus the auto baseline string,
# so any change to the options produces a different cache file name.
cache_key = hashlib.md5((toml.dumps(toml_options) + auto_baseline_string).encode()).hexdigest()
auto_cache_file = os.path.join(OUTDIR, f'avg_autos_cache_{"preliminary_" if PRELIMINARY else ""}{cache_key}.npz')
if os.path.exists(auto_cache_file):
print(f'Loading cached averaged autos from {auto_cache_file}')
# allow_pickle=True is needed because auto_times_in_bins is stored as an object array below.
# NOTE(review): unpickling is only safe if everything written to OUTDIR is trusted.
with np.load(auto_cache_file, allow_pickle=True) as cache:
# sanity check: the hash stored inside the file must match the one in its name
assert cache['toml_hash'].item() == cache_key, 'Cache hash mismatch'
lst_avg_auto_data = cache['lst_avg_auto_data']
lst_avg_auto_flags = cache['lst_avg_auto_flags']
lst_avg_auto_nsamples = cache['lst_avg_auto_nsamples']
# slice_kept was saved as [start, stop, step]; rebuild the slice object
slice_kept = slice(*cache['slice_kept_args'])
auto_times_in_bins = list(cache['auto_times_in_bins'])
auto_bin_lst = cache['auto_bin_lst']
else:
print(f'No cache found at {auto_cache_file}, computing averaged autos...')
autos = SingleBaselineStacker.from_configurator(configurator,
auto_baseline_string,
lst_bin_edges,
lst_branch_cut=lst_branch_cut,
where_inpainted_file_rules=where_inpainted_file_rules,
cal_file_loader=cal_file_loader,
)
lst_avg_auto_data, lst_avg_auto_flags, lst_avg_auto_nsamples = autos.average_over_nights(inpainted_data_are_samples=True)
# copy what is needed later so the stacker itself can be deleted
slice_kept = copy.deepcopy(autos.slice_kept)
auto_times_in_bins = copy.deepcopy(autos.times_in_bins)
auto_bin_lst = autos.bin_lst.copy()
del autos
# save cache atomically to handle parallel race conditions
# The temp file is created inside OUTDIR (same directory as the final cache file) and
# already ends in '.npz', so np.savez writes to exactly tmp_path.
tmp_fd, tmp_path = tempfile.mkstemp(dir=OUTDIR, suffix='.npz')
os.close(tmp_fd)
# NOTE(review): if np.savez raises, tmp_path is left behind in OUTDIR.
np.savez(tmp_path,
toml_hash=cache_key,
lst_avg_auto_data=lst_avg_auto_data,
lst_avg_auto_flags=lst_avg_auto_flags,
lst_avg_auto_nsamples=lst_avg_auto_nsamples,
slice_kept_args=np.array([slice_kept.start, slice_kept.stop, slice_kept.step]),
auto_times_in_bins=np.array(auto_times_in_bins, dtype=object),
auto_bin_lst=auto_bin_lst)
try:
# NOTE(review): on POSIX os.rename silently replaces an existing destination, so the
# except branch is presumably only reached where rename refuses to overwrite - confirm.
os.rename(tmp_path, auto_cache_file)
print(f'Saved averaged autos cache to {auto_cache_file}')
except OSError:
os.remove(tmp_path) # another process beat us
print(f'Cache already written by another process.')
gc_and_malloc_trim()
print_peak_rss("after SingleBaselineStacker autos load")
Loading cached averaged autos from /lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights/avg_autos_cache_cda036938bd4f9c84402373dc5a92101.npz
[mem] after SingleBaselineStacker autos load: current RSS = 1.58 GB, peak RSS = 1.58 GB, elapsed = 0.52 min
crosses = SingleBaselineStacker.from_configurator(configurator,
baseline_string,
lst_bin_edges,
lst_branch_cut=lst_branch_cut,
to_keep_slice=slice_kept,
where_inpainted_file_rules=where_inpainted_file_rules,
cal_file_loader=cal_file_loader,
)
gc_and_malloc_trim()
print_peak_rss("after SingleBaselineStacker crosses load")
Getting antpos from the first file only. This is almost always correct, but will be wrong if different files have different antenna_position arrays.
[mem] after SingleBaselineStacker crosses load: current RSS = 4.02 GB, peak RSS = 7.03 GB, elapsed = 1.06 min
# convert bin_lsts to JD, starting with the first day
lat, lon, alt = hd.telescope.location_lat_lon_alt_degrees
bin_times = utils.LST2JD(auto_bin_lst, int(auto_times_in_bins[0][0]), allow_other_jd=True, latitude=lat, longitude=lon, altitude=alt)
Perform re-inpainting of data with high modified z-score relative to other nights¶
# compute modified z-scores across nights for each night
modz_const = 2**.5 * scipy.special.erfinv(.5)
def _compute_mod_z(d, f):
ma = np.ma.array(d, mask=f)
med = np.ma.median(ma, axis=0, keepdims=True)
MAD = np.ma.median(np.abs(ma - med), axis=0, keepdims=True)
return modz_const * (ma - med) / MAD
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
mod_zs = list(executor.map(_compute_mod_z, crosses.data, crosses.flags))
gc_and_malloc_trim()
print_peak_rss('after computing mod_zs')
[mem] after computing mod_zs: current RSS = 5.63 GB, peak RSS = 7.03 GB, elapsed = 1.44 min
# perform Fourier filtering with cached DPSS operators
CACHE = {}


def freq_filter(freqs, data, wgts, filter_half_widths=[INPAINT_DELAY * 1e-9], eigenval_cutoff=[EIGENVAL_CUTOFF]):
    '''Thin wrapper around hera_filters.dspec.fourier_filter.

    Runs a 1D 'dpss_solve' filter centered at 0, sharing the module-level CACHE
    dict between calls. The eigenvalue cutoff is also passed as the suppression
    factor, and max_contiguous_edge_flags is set to the number of channels in
    hd.freqs.

    Arguments:
        freqs: frequencies passed straight through to fourier_filter
        data: data passed straight through to fourier_filter
        wgts: weights passed to fourier_filter as wgts
        filter_half_widths: one-element list, defaults to INPAINT_DELAY in ns converted via 1e-9
        eigenval_cutoff: one-element list, defaults to EIGENVAL_CUTOFF

    Returns:
        whatever fourier_filter returns
    '''
    # NOTE: the list defaults are built once at definition time; they are never mutated here
    filter_options = dict(
        wgts=wgts,
        filter_centers=[0],
        filter_half_widths=filter_half_widths,
        mode='dpss_solve',
        eigenval_cutoff=eigenval_cutoff,
        suppression_factors=eigenval_cutoff,
        max_contiguous_edge_flags=len(hd.freqs),
        filter_dims=1,
        cache_solver_products=False,
        cache=CACHE,
    )
    return fourier_filter(freqs, data, **filter_options)
# Pre-warm cache for both bands so parallel workers only read from it
for band in [low_band, high_band]:
_dummy = np.zeros(len(hd.freqs[band]))
_wgts = np.ones(len(hd.freqs[band]))
freq_filter(hd.freqs[band], _dummy, wgts=_wgts)
del _dummy, _wgts
Casting complex values to real discards the imaginary part
# re-inpaint inpainted data with high z-scores
# Get antennas that make up the baseline
ai, aj = list(map(int, baseline_string.split('_')))
IS_AUTO = (ai == aj)
# autos are only re-inpainted when REINPAINT_AUTOS is set
if not IS_AUTO or REINPAINT_AUTOS:
def _reinpaint_one_lst_bin(d, f, n, aa, mod_z, wip):
'''Re-inpaint data for a single LST bin where modified z-scores are high.

Arguments (one LST bin each, as supplied by the executor.map call below):
d: data, indexed [night, freq, pol]; modified in place
f: flags, indexed [night, freq, pol]; modified in place
n: nsamples, indexed [night, freq, pol]
aa: LST-averaged autocorrelations, indexed [freq, pol]; used for inverse-variance-like weights
mod_z: masked array of modified z-scores, indexed [night, freq, pol]
wip: where_inpainted booleans, indexed [night, freq, pol]

Returns nothing; all results are written into d and f.
'''
# (freq, pol) pixels with fewer than MIN_NIGHTS_PER_BIN unflagged nights are flagged on every night
enough_nights = np.sum(~f, axis=0) >= MIN_NIGHTS_PER_BIN
f[:, ~enough_nights] = True # flag bins with not enough nights
to_reip = np.zeros_like(f)
for pol in crosses.hd.pols:
pidx = crosses.hd.pols.index(pol)
# get indices for indexing into autocorrelations for weights
p1, p2 = utils.split_pol(pol)
pidx1 = crosses.hd.pols.index(utils.join_pol(p1, p1))
pidx2 = crosses.hd.pols.index(utils.join_pol(p2, p2))
for band in [low_band, high_band]:
for tidx in range(mod_z.shape[0]):
# a single previously-inpainted, unflagged outlier channel marks the whole
# band of that night for re-inpainting
if np.any(wip[tidx, band, pidx] &
(~f[tidx, band, pidx]) &
enough_nights[band, pidx] &
(np.abs(mod_z.data[tidx, band, pidx]) > MOD_Z_TO_REINPAINT)):
to_reip[tidx, band, pidx] = True
if np.any(to_reip[:, band, pidx]):
# weighted average data across nights, excluding nights that are to be reinpainted
nights_with_reip = np.any(to_reip[:, band, pidx], axis=1)
if np.all(f[~nights_with_reip, band, pidx]):
f[:, band, :] = True # there are no unflagged nights with consistently good z-scores, so flag the whole integration
continue
# nsamples-weighted mean over the remaining nights; flagged samples contribute nothing
avg_data_here = np.nansum(np.where(f[~nights_with_reip, band, pidx], np.nan,
d[~nights_with_reip, band, pidx] * n[~nights_with_reip, band, pidx]), axis=0)
avg_data_here /= np.sum(np.where(f[~nights_with_reip, band, pidx], 0,
n[~nights_with_reip, band, pidx]), axis=0)
# fit that average with DPSS
# NOTE(review): samples_here sums over ALL nights, including those excluded from
# avg_data_here above - confirm this is intended.
samples_here = np.sum([n[tidx, band, pidx] * ~f[tidx, band, pidx]
for tidx in range(n.shape[0])], axis=0)
wgts = np.where(~np.isfinite(avg_data_here), 0, aa[band, pidx1]**-1 * aa[band, pidx2]**-1 * samples_here)
avg_data_clean = np.where(~np.isfinite(avg_data_here), 0, avg_data_here)
# NOTE(review): the [0] index runs before the assert, so if there is no stretch of
# non-zero weights an IndexError is raised instead of the assertion message.
ts_avg = true_stretches(wgts > 0)[0]
assert len(true_stretches(wgts > 0)) == 1, "Expected only one stretch of non-zero wgts"
avg_mdl = np.zeros_like(avg_data_clean)
avg_mdl[ts_avg], *_ = freq_filter(hd.freqs[band][ts_avg], avg_data_clean[ts_avg], wgts=wgts[ts_avg])
# perform re-inpainting with feathered weights
for tidx in range(d.shape[0]):
if np.any(to_reip[tidx, band, pidx]):
# figure out feathered weights
# logistic ramp in distance from the nearest not-inpainted channel: the relative weight is
# INPAINT_ZERO_DIST_WEIGHT at distance 0 and 0.5 at distance == width
distances = flag_utils.distance_to_nearest_nonzero(~wip[tidx, band, pidx])
width = (1e-9 * INPAINT_DELAY)**-1 / df * INPAINT_WIDTH_FACTOR
rel_weights = (1 + np.exp(-np.log(INPAINT_ZERO_DIST_WEIGHT**-1 - 1) / width * (distances - width)))**-1
wgts_here = np.where(wip[tidx, band, pidx], wgts * rel_weights, wgts)
# re-inpaint with DPSS
# inpainted or flagged channels are fit to the across-night model, the rest to this night's data
to_fit = np.where(wip[tidx, band, pidx] | f[tidx, band, pidx],
avg_mdl, d[tidx, band, pidx])
# NOTE(review): same ordering issue as above - the [0] index precedes the assert.
ts = true_stretches(wgts_here > 0)[0]
assert len(true_stretches(wgts_here > 0)) == 1, "Expected only one stretch of non-zero wgts_here"
reip_mdl = np.zeros_like(to_fit)
reip_mdl[ts], *_ = freq_filter(hd.freqs[band][ts], to_fit[ts], wgts=wgts_here[ts])
# only previously-inpainted channels are overwritten with the new model
d[tidx, band, pidx] = np.where(wip[tidx, band, pidx], reip_mdl, d[tidx, band, pidx])
# one task per LST bin; the worker mutates the per-bin arrays in place, so results are discarded
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
list(executor.map(_reinpaint_one_lst_bin,
crosses.data, crosses.flags, crosses.nsamples,
lst_avg_auto_data, mod_zs, crosses.where_inpainted))
invalid value encountered in divide invalid value encountered in reciprocal
# flag channels that are too often flagged across nights
# fully_flagged[lst, freq, pol] is True where every night in that LST bin is flagged
fully_flagged = np.array([np.all(f, axis=0) for f in crosses.flags])
for pidx in range(fully_flagged.shape[-1]):
for band in [low_band, high_band]:
tslice = flag_utils.get_minimal_slices(fully_flagged[:, band, pidx])[0][0]
if tslice is not None:
# fraction of LST bins (within tslice) in which each channel is flagged on all nights
too_often_flagged_chans = np.mean(fully_flagged[tslice, band, pidx], axis=0) > MAX_CHANNEL_NIGHTLY_FLAG_FRAC
for f in crosses.flags:
f[:, band, pidx] |= too_often_flagged_chans
# if the vast majority of the waterfall (typically defined by other pols) is flagged
# NOTE(review): fully_flagged is not recomputed after the flag updates above, so this test
# sees the flags as they were before this cell ran - confirm that is intended.
if np.mean(fully_flagged[:, band, pidx]) > 0.95:
for f in crosses.flags:
f[:, band, :] = True
# Harmonize flags across polarization
def _pol_harmonize():
    """Flag every polarization of a band in an LST bin if any one polarization is fully flagged there.

    Operates in place on each array in ``crosses.flags`` for both ``low_band``
    and ``high_band``; returns nothing.
    """
    for band in (low_band, high_band):
        for flags_here in crosses.flags:
            # per-polarization: is every (night, channel) sample in this band flagged?
            pol_fully_flagged = np.all(flags_here[:, band], axis=(0, 1))
            if np.any(pol_fully_flagged):
                flags_here[:, band, :] = True
_pol_harmonize()
# Fully flag any JD that is already >95% flagged
FRAC_FLAG_THRESHOLD = 0.95
tally = defaultdict(lambda: [0, 0]) # jd -> [flagged, total]
for f, tib in zip(crosses.flags, crosses.times_in_bins):
for i, jd in enumerate(np.floor(tib).astype(int)):
tally[int(jd)][0] += int(f[i].sum())
tally[int(jd)][1] += f[i].size
to_wipe = {jd: fl / tot for jd, (fl, tot) in tally.items() if fl / tot > FRAC_FLAG_THRESHOLD}
for f, tib, n in zip(crosses.flags, crosses.times_in_bins, crosses.nsamples):
mask = np.isin(np.floor(tib).astype(int), list(to_wipe))
f[mask] = True
n[mask] = 0
print(f"Wiped {len(to_wipe)}/{len(tally)} JDs >{FRAC_FLAG_THRESHOLD * 100:.0f}% flagged:")
for jd, frac in sorted(to_wipe.items(), key=lambda kv: -kv[1]):
print(f" {jd}: {frac*100:.2f}%")
_pol_harmonize()
# Figure out entirely flagged integrations
all_flagged = np.array([np.all(f, axis=0) for f in crosses.flags])
Wiped 5/5 JDs >95% flagged: 2459860: 100.00% 2459869: 100.00% 2459876: 100.00% 2459862: 100.00% 2459917: 100.00%
def _max_abs_or_nan(modz, tib=None):
'''Max |modz| along axis 0, treating masked entries as -inf and returning NaN
where all entries along axis 0 are masked. If `tib` is provided, also returns
per-pixel attribution: the floored JD of the night holding the maximum.
'''
if modz.shape[0] == 0:
nan_slice = np.full(modz.shape[1:], np.nan)
return nan_slice if tib is None else (nan_slice, nan_slice)
abs_data = np.where(np.ma.getmaskarray(modz), -np.inf, np.abs(modz.data)) # -inf can never be max
idx = np.argmax(abs_data, axis=0)
max_abs = np.take_along_axis(abs_data, idx[None], axis=0)[0]
max_out = np.where(max_abs == -np.inf, np.nan, max_abs)
if tib is None:
return max_out
return max_out, np.where(np.isnan(max_out), np.nan, np.floor(tib[idx]))
# compute summary statistics for modified z-scores, then delete them to save memory
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
_max_list = list(executor.map(_max_abs_or_nan, mod_zs))
max_mod_z_before = np.where(all_flagged, np.nan, np.array(_max_list))
del _max_list
del mod_zs
gc_and_malloc_trim()
print_peak_rss("after re-inpainting")
[mem] after re-inpainting: current RSS = 4.89 GB, peak RSS = 7.03 GB, elapsed = 2.20 min
Expaint band edges (i.e. use other nights to extrapolate)¶
# Direct-DPSS path used only by the band-edge expainter. It skips the overhead
# of fourier_filter, since the fit is called many times on very small ranges
# of channels. Only the DPSS operator is cached, which is translation invariant.
_EXPAINT_DPSS_CACHE = {}


def _expaint_dpss_fit(freqs, data, wgts,
                      filter_half_widths=[INPAINT_DELAY * 1e-9],
                      eigenval_cutoff=[EIGENVAL_CUTOFF]):
    '''Weighted least-squares fit of data onto a DPSS basis, returning the model.

    Per the original author's note this matches fourier_filter(mode="dpss_solve",
    ridge_alpha=0) to ~1e-10 relative, and is scoped to the expaint-edges cell.

    Arguments:
        freqs: frequencies of the channels being fit (at least two are needed for the cache key)
        data: values to fit, one per frequency
        wgts: real weights, one per frequency
        filter_half_widths: one-element list; only the first entry is used
        eigenval_cutoff: one-element list; only the first entry is used

    Returns:
        the DPSS model evaluated at freqs
    '''
    half_width = float(filter_half_widths[0])
    cutoff = float(eigenval_cutoff[0])

    # the basis is keyed on channel count and spacing, not on absolute frequency
    cache_key = (len(freqs), float(freqs[1] - freqs[0]), half_width, cutoff)
    basis = _EXPAINT_DPSS_CACHE.get(cache_key)
    if basis is None:
        operator, _ = dpss_operator(freqs, [0.0], [half_width], eigenval_cutoff=[cutoff])
        basis = np.ascontiguousarray(np.asarray(operator))
        _EXPAINT_DPSS_CACHE[cache_key] = basis

    # solve the normal equations (B^T W B) c = B^T W data
    weights = np.asarray(wgts, dtype=np.float64)
    weighted_basis_T = (basis * weights[:, None]).T
    coeffs = np.linalg.solve(weighted_basis_T @ basis, weighted_basis_T @ data)
    return basis @ coeffs
def _expaint_edges_one_lst_bin(d, f, n, aa, wip):
'''Perform edge extrapolation for a single LST bin using other nights' data.

For every pol and band, each night's unflagged range is extended so that all
not-fully-flagged nights end at the same last channel and start at the same
first channel, using a DPSS model built from nights that already reach further.

Arguments (one LST bin each):
d: data, indexed [night, freq, pol]; modified in place
f: flags, indexed [night, freq, pol]; modified in place
n: nsamples, indexed [night, freq, pol]
aa: LST-averaged autocorrelations, indexed [freq, pol]; used for weights
wip: where_inpainted booleans, indexed [night, freq, pol]; modified in place

Returns:
d, f, wip: the same (modified) arrays, so results can be returned from a worker
'''
for pol in crosses.hd.pols:
pidx = crosses.hd.pols.index(pol)
# get indices for indexing into autocorrelations for weights
p1, p2 = utils.split_pol(pol)
pidx1 = crosses.hd.pols.index(utils.join_pol(p1, p1))
pidx2 = crosses.hd.pols.index(utils.join_pol(p2, p2))
for band in [low_band, high_band]:
# basic slicing: these are views, so the in-place writes made by _expaint_edge below
# are visible through d_here / f_here on later iterations
d_here, f_here, n_here = d[:, band, pidx], f[:, band, pidx], n[:, band, pidx]
if np.all(f_here):
continue
# both dicts map night index -> channel index within this band
night_to_last_unflagged = {}
night_to_first_unflagged = {}
for tidx in range(f.shape[0]):
if np.all(f[tidx, band, pidx]):
continue
# find the first and last unflagged channels on this night, for this band and pol
unflagged_here = ~f[tidx, band, pidx]
night_to_first_unflagged[tidx] = unflagged_here.argmax()
night_to_last_unflagged[tidx] = len(unflagged_here) - 1 - unflagged_here[::-1].argmax(axis=-1)
def _expaint_edge(to_fit_slice, nights_to_avg, tidx):
# average over nights with more data
# NOTE(review): this average is weighted by nsamples only and does not consult f_here for
# the nights being averaged - confirm flagged samples there carry zero nsamples.
avg_data = np.nansum(d_here[nights_to_avg, to_fit_slice] * n_here[nights_to_avg, to_fit_slice], axis=0)
avg_data /= np.sum(n_here[nights_to_avg, to_fit_slice], axis=0)
wgts = (aa[band, pidx1]**-1 * aa[band, pidx2]**-1)[to_fit_slice] # don't need nsamples, because it's flat
avg_mdl = _expaint_dpss_fit(hd.freqs[band][to_fit_slice], avg_data, wgts=wgts)
# perform re-inpainting with feathered weights
# logistic ramp in distance from the nearest unflagged channel: the relative weight is
# INPAINT_ZERO_DIST_WEIGHT at distance 0 and 0.5 at distance == width
distances = flag_utils.distance_to_nearest_nonzero(~f_here[tidx, to_fit_slice])
width = (1e-9 * INPAINT_DELAY)**-1 / df * INPAINT_WIDTH_FACTOR
rel_weights = (1 + np.exp(-np.log(INPAINT_ZERO_DIST_WEIGHT**-1 - 1) / width * (distances - width)))**-1
wgts_here = np.where(f_here[tidx, to_fit_slice], wgts * rel_weights, wgts)
# re-inpaint with DPSS
# flagged channels are fit to the across-night model, unflagged ones to this night's data
to_fit = np.where(f_here[tidx, to_fit_slice], avg_mdl, d_here[tidx, to_fit_slice])
xp_mdl = _expaint_dpss_fit(hd.freqs[band][to_fit_slice], to_fit, wgts=wgts_here)
# modify data, flags, and where_inpainted arrays in place
# freq_indices converts the within-band fit slice into indices along the full frequency axis
freq_indices = np.arange(len(hd.freqs))[band][to_fit_slice]
d[tidx, freq_indices, pidx] = np.where(f_here[tidx, to_fit_slice], xp_mdl, d_here[tidx, to_fit_slice])
wip[tidx, freq_indices, pidx] = np.where(f_here[tidx, to_fit_slice], True, wip[tidx, freq_indices, pidx])
# must come last: f_here is a view of f, and the two lines above still need the old flags
f[tidx, freq_indices, pidx] = False
# first, perform ex-painting on the upper edge of the band (highest unflagged channel index)
# nights are visited from the furthest-reaching to the least, so sorted_nights[:i] only contains
# nights that already extend to the target channel (natively or after ex-painting)
sorted_nights = sorted(night_to_last_unflagged.keys(), key=lambda x: night_to_last_unflagged[x], reverse=True)
target_last_unflagged = night_to_last_unflagged[sorted_nights[0]]
for i, tidx in enumerate(sorted_nights):
if night_to_last_unflagged[tidx] == target_last_unflagged:
continue # no additional extrapolation necessary
# figure out which channels to fit on this particular night to get a good model to expaint with
# fit range = the gap plus 2 / (INPAINT_DELAY * df) channels for small gaps, or twice the gap otherwise
Nchans_to_fit = (target_last_unflagged - night_to_last_unflagged[tidx])
if Nchans_to_fit < (2 / (INPAINT_DELAY * 1e-9) / df):
Nchans_to_fit += int(np.ceil(2 / (INPAINT_DELAY * 1e-9) / df))
else:
Nchans_to_fit *= 2
# NOTE(review): if Nchans_to_fit > target_last_unflagged + 1 the slice start is negative and
# wraps around to the end of the band instead of stopping at channel 0 - confirm this cannot occur.
to_fit_slice = slice(target_last_unflagged - Nchans_to_fit + 1, target_last_unflagged + 1)
_expaint_edge(to_fit_slice, sorted_nights[:i], tidx)
# now, perform ex-painting on the lower edge of the band (lowest unflagged channel index)
sorted_nights = sorted(night_to_first_unflagged.keys(), key=lambda x: night_to_first_unflagged[x])
target_first_unflagged = night_to_first_unflagged[sorted_nights[0]]
for i, tidx in enumerate(sorted_nights):
if night_to_first_unflagged[tidx] == target_first_unflagged:
continue # no additional extrapolation necessary
# figure out which channels to fit on this particular night to get a good model to expaint with
Nchans_to_fit = (night_to_first_unflagged[tidx] - target_first_unflagged)
if Nchans_to_fit < (2 / (INPAINT_DELAY * 1e-9) / df):
Nchans_to_fit += int(np.ceil(2 / (INPAINT_DELAY * 1e-9) / df))
else:
Nchans_to_fit *= 2
# a stop beyond the end of the band is simply truncated by slicing
to_fit_slice = slice(target_first_unflagged, target_first_unflagged + Nchans_to_fit)
_expaint_edge(to_fit_slice, sorted_nights[:i], tidx)
return d, f, wip
# Parallelize over bins with processes (fork): workers inherit this cell
# via copy-on-write, so _expaint_edges_one_lst_bin needn't be importable.
# Cap at 8 (benchmarks: diminishing returns past n=8; threads measured
# strictly slower due to OpenBLAS cross-thread contention).
# Bound futures in flight to 2*N_workers to prevent large memory overhead.
_n_proc = min(NUM_THREADS, 8)
_max_inflight = 2 * _n_proc
_fork_ctx = multiprocessing.get_context("fork")
_tasks = enumerate(zip(crosses.data, crosses.flags, crosses.nsamples,
lst_avg_auto_data, crosses.where_inpainted))
with ProcessPoolExecutor(max_workers=_n_proc, mp_context=_fork_ctx) as executor:
_submit = lambda n: {executor.submit(_expaint_edges_one_lst_bin, *a): i
for i, a in itertools.islice(_tasks, n)}
_pending = _submit(_max_inflight)
while _pending:
_done, _ = wait(_pending, return_when=FIRST_COMPLETED)
for _fut in _done:
_i = _pending.pop(_fut)
_d, _f, _wip = _fut.result()
crosses.data[_i][...] = _d
crosses.flags[_i][...] = _f
crosses.where_inpainted[_i][...] = _wip
_pending.update(_submit(len(_done)))
gc_and_malloc_trim()
print_peak_rss("after expaint band edges")
[mem] after expaint band edges: current RSS = 4.89 GB, peak RSS = 7.03 GB, elapsed = 2.29 min
2D-informed expainting to get even band edges¶
FR_CENTER_AND_HW_CACHE = {}
_FR_CACHE_LOCK = threading.Lock()


def cache_fr_center_and_hw(hd, antpair, tslice, band):
    '''Compute and cache the fringe-rate (FR) center and half-width for one (antpair, tslice, band).

    The half-width is the sky FR half-width for the selected band plus a buffer
    derived from the autocorrelation FR spectrum in AUTO_FR_SPECTRUM_FILE. Both
    values are divided by 1e3 to convert to Hz and stored in
    FR_CENTER_AND_HW_CACHE[(antpair, tslice, band)] (if not already present).

    Arguments:
        hd: data object providing .times, .freqs and .select()
        antpair: antenna pair used as part of the cache key
        tslice: time selection into hd.times; nothing is done if None
        band: frequency selection into hd.freqs; nothing is done if None

    Returns:
        None; the result is only available through FR_CENTER_AND_HW_CACHE.

    NOTE(review): callers appear to pass slice objects for tslice and band; slices
    are only hashable (usable in a dict key) on Python >= 3.12 - confirm the runtime.
    '''
    if (tslice is None) or (band is None):
        return
    key = (antpair, tslice, band)
    if key in FR_CENTER_AND_HW_CACHE:
        return  # fast path: no lock needed once the entry exists
    with _FR_CACHE_LOCK:
        # re-check after acquiring the lock: another thread may have filled it meanwhile
        if key in FR_CENTER_AND_HW_CACHE:
            return
        # calculate fringe rate center and half-width and then update cache
        fr_buffer = get_FR_buffer_from_spectra(AUTO_FR_SPECTRUM_FILE, hd.times[tslice], hd.freqs[band],
                                               gauss_fit_buffer_cut=GAUSS_FIT_BUFFER_CUT)
        hd_here = hd.select(inplace=False, frequencies=hd.freqs[band])
        # evaluate sky_frates once and reuse both outputs (centers, half-widths)
        frates = sky_frates(hd_here)
        fr_center = list(frates[0].values())[0] / 1e3  # converts to Hz
        fr_hw = (list(frates[1].values())[0] + fr_buffer) / 1e3
        FR_CENTER_AND_HW_CACHE[key] = fr_center, fr_hw
def fit_2D_DPSS(data, weights, filter_delay, tslices, bands, **kwargs):
    '''Fit a 2D (time x frequency) DPSS model to every baseline in data.

    The time basis is built from the fringe-rate center and half-width stored in
    FR_CENTER_AND_HW_CACHE for each (antpair, tslice, band); the frequency basis
    is centered at zero delay with half-width filter_delay.

    Arguments:
        data: datacontainer mapping baselines to complex visibility waterfalls
        weights: datacontainer mapping baselines to real weight waterfalls.
        filter_delay: maximum delay in ns for the 2D filter
        tslices: dictionary mapping bl to time slices corresponding to low and high bands
        bands: dictionary mapping bl to low band and high band frequency slices
        kwargs: kwargs to pass into sparse_linear_fit_2D()

    Returns:
        dpss_fit: datacontainer mapping baselines to 2D DPSS models; NaN wherever
            no fit was performed
    '''
    dpss_fit = copy.deepcopy(data)
    for bl in data.keys():
        # everything starts as NaN; only regions that are actually fit get overwritten
        dpss_fit[bl] *= np.nan
        for t_sel, f_sel in zip(tslices[bl], bands[bl]):
            if t_sel is None or f_sel is None:
                continue
            wgts_here = weights[bl][t_sel, f_sel]
            if np.all(wgts_here == 0):
                continue

            # build the two 1D bases of the separable 2D filter
            fr_center, fr_hw = FR_CENTER_AND_HW_CACHE[(bl[0:2], t_sel, f_sel)]
            seconds = (data.times[t_sel] - data.times[t_sel][0]) * 3600 * 24
            time_basis, _ = dpss_operator(seconds, [fr_center], [fr_hw],
                                          eigenval_cutoff=[EIGENVAL_CUTOFF])
            freq_basis, _ = dpss_operator(data.freqs[f_sel], [0.0], [filter_delay / 1e9],
                                          eigenval_cutoff=[EIGENVAL_CUTOFF])

            coeffs, _meta = sparse_linear_fit_2D(
                data=data[bl][t_sel, f_sel],
                weights=wgts_here,
                axis_1_basis=time_basis,
                axis_2_basis=freq_basis,
                precondition_solver=True,
                iter_lim=CG_ITER_LIM,
                **kwargs,
            )
            # project the fitted coefficients back onto the (time, frequency) grid
            dpss_fit[bl][t_sel, f_sel] = time_basis.dot(coeffs).dot(freq_basis.T)
    return dpss_fit
nnights_unflagged = np.array([np.sum((~f[:, :, :]).astype(float), axis=0) for f in crosses.flags])
ntimes_unflagged = np.sum(nnights_unflagged, axis=0)
# perform feathered ex-painting on all pols, bands (top and bottom), and nights
def _expaint_2D_pol_band(pol, band):
pidx = crosses.hd.pols.index(pol)
# get indices for indexing into autocorrelations for weights
p1, p2 = utils.split_pol(pol)
pidx1 = crosses.hd.pols.index(utils.join_pol(p1, p1))
pidx2 = crosses.hd.pols.index(utils.join_pol(p2, p2))
# find the range of frequencies that could need explainting
max_unflagged = np.max(ntimes_unflagged[band, pidx])
if max_unflagged == 0:
return
first_unflagged_channel = np.where(ntimes_unflagged[band, pidx] > 0)[0][0]
first_minimally_flagged_channel = np.where(ntimes_unflagged[band, pidx] == max_unflagged)[0][0]
last_minimally_flagged_channel = np.where(ntimes_unflagged[band, pidx] == max_unflagged)[0][-1]
last_unflagged_channel = np.where(ntimes_unflagged[band, pidx] > 0)[0][-1]
tslice = flag_utils.get_minimal_slices(nnights_unflagged[:, band, pidx] == 0)[0][0]
def _expaint_2D(fslice):
'''Fits a 2D DPSS model to the data we do have, then performs feathered ex-painting'''
cache_fr_center_and_hw(hd, hd.antpairs[0], tslice, fslice)
data_here = np.array([np.nansum(np.where(f[:, fslice, pidx], np.nan,
d[:, fslice, pidx] * n[:, fslice, pidx]), axis=0)
for d, f, n in zip(crosses.data, crosses.flags, crosses.nsamples)])[tslice]
nsamples_here = np.array([np.sum(n[:, fslice, pidx] * ~f[:, fslice, pidx], axis=0)
for f, n in zip(crosses.flags, crosses.nsamples)])[tslice]
data_here = np.where(nsamples_here > 0, data_here / nsamples_here, 0.0)
wgts_here = np.array([aa[fslice, pidx1]**-1 * aa[fslice, pidx2]**-1 for aa in lst_avg_auto_data])[tslice] * nsamples_here
wgts_here = np.where(np.isfinite(wgts_here), wgts_here, 0.0)
# perform 2D DPSS filter
fr_center, fr_hw = FR_CENTER_AND_HW_CACHE[(hd.antpairs[0], tslice, fslice)]
time_filters, _ = dpss_operator((bin_times[tslice] - bin_times[tslice][0]) * 3600 * 24,
[fr_center], [fr_hw], eigenval_cutoff=[EIGENVAL_CUTOFF])
freq_filters, _ = dpss_operator(hd.freqs[fslice], [0.0], [INPAINT_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])
fit, meta = sparse_linear_fit_2D(data=data_here, weights=wgts_here, precondition_solver=True,
axis_1_basis=time_filters, axis_2_basis=freq_filters,
iter_lim=CG_ITER_LIM, atol=CG_TOL, btol=CG_TOL)
dpss_fit = time_filters.dot(fit).dot(freq_filters.T)
# perform feathered expainting on a per LST and per night basis
for lidx, (d, f, aa, wip) in list(enumerate(zip(crosses.data, crosses.flags, lst_avg_auto_data, crosses.where_inpainted)))[tslice]:
for tidx in range(d.shape[0]):
if np.any(f[tidx, fslice, pidx]) and not np.all(f[tidx, fslice, pidx]):
wgts = aa[fslice, pidx1]**-1 * aa[fslice, pidx2]**-1 # don't need nsamples, because it's flat
wgts[~np.isfinite(wgts)] = np.min(wgts[np.isfinite(wgts)]) # handle case where we're expainting beyond autos
distances = flag_utils.distance_to_nearest_nonzero(~f[tidx, fslice, pidx])
width = (1e-9 * INPAINT_DELAY)**-1 / df * INPAINT_WIDTH_FACTOR
rel_weights = (1 + np.exp(-np.log(INPAINT_ZERO_DIST_WEIGHT**-1 - 1) / width * (distances - width)))**-1
to_fit = np.where(f[tidx, fslice, pidx], dpss_fit[lidx - tslice.start], d[tidx, fslice, pidx])
wgts = np.where(f[tidx, fslice, pidx], wgts * rel_weights, wgts)
xp_mdl, *_ = freq_filter(hd.freqs[fslice], to_fit, wgts=wgts)
# modify data, flags, and where_inpainted arrays in place
d[tidx, fslice, pidx] = np.where(f[tidx, fslice, pidx], xp_mdl, d[tidx, fslice, pidx])
wip[tidx, fslice, pidx] = np.where(f[tidx, fslice, pidx], True, wip[tidx, fslice, pidx])
f[tidx, fslice, pidx] = False
if first_unflagged_channel < first_minimally_flagged_channel:
Nchans_to_fit = first_minimally_flagged_channel - first_unflagged_channel
if Nchans_to_fit < (2 / (INPAINT_DELAY * 1e-9) / df):
Nchans_to_fit += int(np.ceil(2 / (INPAINT_DELAY * 1e-9) / df))
else:
Nchans_to_fit *= 2
fslice = slice(band.start + first_unflagged_channel,
min(band.start + first_unflagged_channel + Nchans_to_fit, band.stop))
_expaint_2D(fslice)
if last_minimally_flagged_channel < last_unflagged_channel:
Nchans_to_fit = last_unflagged_channel - last_minimally_flagged_channel
if Nchans_to_fit < (2 / (INPAINT_DELAY * 1e-9) / df):
Nchans_to_fit += int(np.ceil(2 / (INPAINT_DELAY * 1e-9) / df))
else:
Nchans_to_fit *= 2
fslice = slice(max(band.start + last_unflagged_channel + 1 - Nchans_to_fit, band.start),
band.start + last_unflagged_channel + 1)
_expaint_2D(fslice)
# Run the ex-painting for every (pol, band) combination on a thread pool.
pol_band_pairs = [(pol, band) for pol in crosses.hd.pols for band in (low_band, high_band)]
with ThreadPoolExecutor(max_workers=min(len(pol_band_pairs), NUM_THREADS)) as executor:
    # list() drains the iterator so any exception raised in a worker surfaces here
    list(executor.map(_expaint_2D_pol_band, *zip(*pol_band_pairs)))
print_peak_rss("after 2D-informed expainting")
[mem] after 2D-informed expainting: current RSS = 5.30 GB, peak RSS = 7.03 GB, elapsed = 2.30 min
Now Actually Average Over Nights¶
# Average over nights; the three returned arrays are kept as the stacked data, flags, and nsamples.
(lst_avg_data,
 lst_avg_flags,
 lst_avg_nsamples) = crosses.average_over_nights()
gc_and_malloc_trim()
print_peak_rss("after average_over_nights")
[mem] after average_over_nights: current RSS = 5.86 GB, peak RSS = 7.03 GB, elapsed = 2.34 min
# Recompute the modified z-scores, one task per LST bin, after re-inpainting.
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
    mod_zs_after = [*executor.map(_compute_mod_z, crosses.data, crosses.flags)]
gc_and_malloc_trim()
print_peak_rss("after second set of mod_zs")
[mem] after second set of mod_zs: current RSS = 7.29 GB, peak RSS = 7.32 GB, elapsed = 2.71 min
# Compute summary statistics for modified z-scores and per-pixel attribution in one
# pass over mod_zs_after, with the per-bin work spread over a thread pool.
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
    _results = [*executor.map(_max_abs_or_nan, mod_zs_after, crosses.times_in_bins)]
# split the per-bin result pairs into two parallel sequences
_max_list, _attr_list = zip(*_results)
attribution_arr = np.array(_attr_list)
# blank out pixels that are flagged in the LST-averaged product
max_mod_z_after = np.where(lst_avg_flags, np.nan, np.array(_max_list))
del _results, _max_list, _attr_list
# Pick the N_JDS nights that most often carry the highest modified z-score in the
# ee/nn polarizations, and give each a fixed tab10 color for the plots below.
N_JDS = 10
auto_pol_mask = np.array([(p == 'ee') or (p == 'nn') for p in hd.pols])
valid = attribution_arr[:, :, auto_pol_mask]
unique_jds, counts = np.unique(valid[np.logical_not(np.isnan(valid))], return_counts=True)
_most_frequent_first = np.argsort(-counts)
top_jds = unique_jds[_most_frequent_first[:N_JDS]]
jd_to_color = {}
for rank, jd in enumerate(top_jds):
    jd_to_color[jd] = matplotlib.colormaps['tab10'](rank)
del mod_zs_after
# Per-night z² summary statistics and histogram counts are gathered in one streaming
# pass over crosses.{data, flags, times_in_bins, lsts_in_bins, where_inpainted}. The
# per-bin z-score step (np.ma.median / np.ma.std) is threaded in batches of
# NUM_THREADS so only a handful of bins are held in memory at a time.
_ordered_jds = list(map(int, configurator.nights))
_n_jds = len(_ordered_jds)
_jd_to_rank = dict(zip(_ordered_jds, range(_n_jds)))
_pol_indices = [crosses.hd.pols.index('ee'), crosses.hd.pols.index('nn')]
_band_slices = [low_band, high_band]
hist_bins = np.arange(0, 5 * MOD_Z_TO_REINPAINT, .2)
# one histogram cell per JD rank, centered on the integer rank
_jd_edges = np.arange(_n_jds + 1, dtype=np.float64) - 0.5

# Per-JD accumulators. _hist_counts axes: (inpainted, pol_ee_nn, band, jd_rank, z²_bin).
_hist_counts = np.zeros((2, 2, 2, _n_jds, hist_bins.size - 1), dtype=np.int64)
_time_sum = np.zeros((_n_jds, crosses.hd.Nfreqs, len(crosses.hd.pols)), dtype=np.float64)
_time_cnt = np.zeros(_time_sum.shape, dtype=np.int64)
_freq_chunks = [list() for _ in range(_n_jds)]
_lst_chunks = [list() for _ in range(_n_jds)]
def _compute_z2(d, f):
'''Return (|z|², mask) for one LST bin, or None for an empty bin.'''
if d.shape[0] == 0:
return None
ma = np.ma.array(d, mask=f)
z = (ma - np.ma.median(ma, axis=0, keepdims=True)) / np.ma.std(ma, axis=0, keepdims=True)
return np.abs(z.data) ** 2, np.ma.getmaskarray(z)
def _accumulate_bin(z2, z_mask, tib, lib, wip):
    '''Adds one LST bin's z² values to the module-level per-JD accumulators.

    Updates _time_sum, _time_cnt, _freq_chunks, _lst_chunks, and _hist_counts in place.
    Rows whose integer JD (the floor of tib) is not a key of _jd_to_rank are ignored.

    Arguments:
        z2: array of |z|² values, indexed as (row, freq, pol)
        z_mask: boolean array of the same shape as z2, True where z2 must not be used
        tib: per-row times for this bin; their floor is used as the JD
        lib: per-row LSTs for this bin
        wip: boolean array indexed like z2, True where the sample was inpainted
    '''
    ranks = np.array([_jd_to_rank.get(int(jd), -1) for jd in np.floor(tib)], dtype=np.int64)
    usable = ~z_mask
    z2_or_zero = np.where(usable, z2, 0.0)

    for rank in np.unique(ranks[ranks >= 0]):
        sel = ranks == rank
        z2_rows, usable_rows = z2_or_zero[sel], usable[sel]
        # running totals over rows, later turned into the time average
        _time_sum[rank] += z2_rows.sum(axis=0)
        _time_cnt[rank] += usable_rows.sum(axis=0)
        # average over the frequency axis, masked where nothing contributed
        n_used = usable_rows.sum(axis=1)
        freq_avg = z2_rows.sum(axis=1) / np.maximum(n_used, 1)
        _freq_chunks[rank].append(np.ma.array(freq_avg, mask=(n_used == 0)))
        _lst_chunks[rank].append(lib[sel])

    # Histograms across (band × pol ∈ {ee,nn} × inpainted ∈ {False, True}).
    for bidx, band in enumerate(_band_slices):
        for pidx, pol_idx in enumerate(_pol_indices):
            vals = z2[:, band, pol_idx]
            rank_grid = np.broadcast_to(ranks[:, None], vals.shape)
            countable = ~z_mask[:, band, pol_idx] & (rank_grid >= 0)
            inpainted = wip[:, band, pol_idx]
            for iidx, keep in enumerate((countable & ~inpainted, countable & inpainted)):
                if not keep.any():
                    continue
                counts_2d = np.histogram2d(rank_grid[keep].astype(np.float64), vals[keep],
                                           bins=(_jd_edges, hist_bins))[0]
                _hist_counts[iidx, pidx, bidx] += counts_2d.astype(np.int64)
# Stream over the LST bins in batches of NUM_THREADS: z² is computed in parallel,
# then folded into the accumulators serially, in bin order.
_n_bins = len(crosses.data)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
    for _start in range(0, _n_bins, NUM_THREADS):
        _chunk = range(_start, min(_start + NUM_THREADS, _n_bins))
        _batch_data = [crosses.data[i] for i in _chunk]
        _batch_flags = [crosses.flags[i] for i in _chunk]
        _results = list(executor.map(_compute_z2, _batch_data, _batch_flags))
        for i, result in zip(_chunk, _results):
            # _compute_z2 returns None for bins with no rows
            if result is not None:
                z2, z_mask = result
                _accumulate_bin(z2, z_mask, crosses.times_in_bins[i],
                                crosses.lsts_in_bins[i], crosses.where_inpainted[i])
# Turn the accumulators into per-JD dicts, dropping JDs that never contributed a chunk.
freq_avg_z2s = {}
time_avg_z2s = {}
lsts_per_jd = {}
hist_counts = {}
for r, jd in enumerate(_ordered_jds):
    if _freq_chunks[r]:
        freq_avg_z2s[jd] = np.ma.concatenate(_freq_chunks[r], axis=0)
        # mean over time, masked where no unmasked sample was ever accumulated
        time_avg_z2s[jd] = np.ma.array(_time_sum[r] / np.maximum(_time_cnt[r], 1), mask=(_time_cnt[r] == 0))
        lsts_per_jd[jd] = np.concatenate(_lst_chunks[r])
        hist_counts[jd] = _hist_counts[:, :, :, r, :].copy()
del _time_sum, _time_cnt, _freq_chunks, _lst_chunks, _hist_counts
gc_and_malloc_trim()
print_peak_rss("after computing summary statistics")
[mem] after computing summary statistics: current RSS = 6.81 GB, peak RSS = 8.47 GB, elapsed = 3.09 min
Visualization¶
def plot_waterfall(data, flags, nsamples, freqs, lsts, bl_label):
    '''Plots data (amplitude and phase) as well as nsamples waterfalls for a baseline.

    Arguments:
        data: 2D complex array of visibilities, indexed as (LST, frequency)
        flags: 2D boolean array of the same shape as data; flagged pixels are left blank
        nsamples: 2D array of the number of samples, same shape as data
        freqs: 1D array of frequencies in Hz (displayed in MHz)
        lsts: 1D array of LSTs in radians, which may wrap through 2pi
        bl_label: baseline label used in the colorbar titles

    Returns None. If everything is flagged, prints a message and draws nothing.
    '''
    if np.all(flags):
        print('This waterfall is entirely flagged. Nothing to plot.')
        return

    # Unwrap LSTs that precede a 2pi wrap, THEN convert radians to hours. The conversion has
    # to apply to both branches; otherwise the unwrapped values are left in radians.
    lsts_in_hours = np.where(lsts > lsts[-1], lsts - 2 * np.pi, lsts) * 12 / np.pi
    extent = [freqs[0]/1e6, freqs[-1]/1e6, lsts_in_hours[-1], lsts_in_hours[0]]

    fig, axes = plt.subplots(1, 3, figsize=(14, 10), sharex=True, sharey=True, gridspec_kw={'wspace': 0}, dpi=200)
    im = axes[0].imshow(np.where(flags, np.nan, np.abs(data)), aspect='auto', norm=matplotlib.colors.LogNorm(),
                        interpolation='none', cmap='inferno', extent=extent)
    fig.colorbar(im, ax=axes[0], location='top', pad=.02).set_label(f'{bl_label}: Amplitude (Jy)', fontsize=16)
    im = axes[1].imshow(np.where(flags, np.nan, np.angle(data)), aspect='auto', cmap='twilight',
                        interpolation='none', extent=extent)
    fig.colorbar(im, ax=axes[1], location='top', pad=.02).set_label(f'{bl_label}: Phase (Radians)', fontsize=16)
    im = axes[2].imshow(np.where(flags, np.nan, nsamples), aspect='auto', interpolation='none', extent=extent)
    fig.colorbar(im, ax=axes[2], location='top', pad=.02).set_label(f'{bl_label}: Number of Samples', fontsize=16)
    plt.tight_layout()

    axes[0].set_ylabel('LST (hours)')
    for ax in axes:
        ax.set_xlabel('Frequency (MHz)')
        # relabel the y ticks modulo 24 hours; the UserWarning raised by set_yticklabels is silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            ax.set_yticklabels([f'{(int(val) if np.isclose(val, int(val)) else val) % 24:n}' for val in ax.get_yticks()])
    plt.tight_layout()
def compare_max_mod_zs(max_mod_z_before, max_mod_z_after, hd, freqs, lsts, antpair):
    '''Compares the maximum modified z-scores before and after re-inpainting for both polarizations.

    Arguments:
        max_mod_z_before: array indexed as (LST, frequency, pol), before re-inpainting
        max_mod_z_after: array indexed as (LST, frequency, pol), after re-inpainting
        hd: object whose .pols list gives the polarization index of 'ee' and 'nn'
        freqs: 1D array of frequencies in Hz (displayed in MHz)
        lsts: 1D array of LSTs in radians, which may wrap through 2pi
        antpair: antenna pair used in the colorbar label

    The color scale runs from 0 to 2 * MOD_Z_TO_REINPAINT. Returns None.
    '''
    # Unwrap LSTs that precede a 2pi wrap, THEN convert radians to hours. The conversion has
    # to apply to both branches; otherwise the unwrapped values are left in radians.
    lsts_in_hours = np.where(lsts > lsts[-1], lsts - 2 * np.pi, lsts) * 12 / np.pi
    extent = [freqs[0]/1e6, freqs[-1]/1e6, lsts_in_hours[-1], lsts_in_hours[0]]

    fig, axes = plt.subplots(1, 4, figsize=(14, 10), sharex=True, sharey=True, gridspec_kw={'wspace': 0}, dpi=200)
    for i, pol in enumerate(['ee', 'nn']):
        pidx = hd.pols.index(pol)
        im = axes[2 * i].imshow(max_mod_z_before[:, :, pidx], aspect='auto', interpolation='none',
                                vmin=0, vmax=MOD_Z_TO_REINPAINT * 2.0, extent=extent)
        axes[2 * i + 1].imshow(max_mod_z_after[:, :, pidx], aspect='auto', interpolation='none',
                               vmin=0, vmax=MOD_Z_TO_REINPAINT * 2.0, extent=extent)
        # put label in top left corner that says polarization and before or after
        axes[2 * i].text(0.03, 0.99, f'{pol} Before', transform=axes[2 * i].transAxes, ha='left', va='top',
                         fontsize=12, bbox=dict(facecolor='white', alpha=0.5, boxstyle='round'))
        axes[2 * i + 1].text(0.03, 0.99, f'{pol} After', transform=axes[2 * i + 1].transAxes, ha='left', va='top',
                             fontsize=12, bbox=dict(facecolor='white', alpha=0.5, boxstyle='round'))

    axes[0].set_ylabel('LST (Hours)')
    for ax in axes:
        ax.set_xlabel('Frequency (MHz)')
        # relabel the y ticks modulo 24 hours; the UserWarning raised by set_yticklabels is silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            ax.set_yticklabels([f'{(int(val) if np.isclose(val, int(val)) else val) % 24:n}' for val in ax.get_yticks()])
    plt.tight_layout()

    # add colorbar with size 16 label (all panels share the same vmin/vmax)
    cbar = plt.colorbar(im, ax=axes, location='top', pad=.02, aspect=50, extend='max')
    cbar.set_label(f'{antpair}: Maximum Modified z-Score Across Nights (unitless)', fontsize=16)
def plot_attribution(attribution_arr, top_jds, hd, freqs, lsts, antpair):
    '''Plots which night has the highest modified z-score at each (LST, freq) pixel.

    Arguments:
        attribution_arr: array indexed as (LST, frequency, pol) holding a JD per pixel, NaN where undefined
        top_jds: JDs that get their own tab10 color; every other JD is drawn in black
        hd: object whose .pols list gives the polarization index of 'ee' and 'nn'
        freqs: 1D array of frequencies in Hz (displayed in MHz)
        lsts: 1D array of LSTs in radians, which may wrap through 2pi
        antpair: antenna pair used in the colorbar label
    '''
    from matplotlib.colors import ListedColormap, BoundaryNorm

    n_top = len(top_jds)
    # rank of each pixel's JD within top_jds; n_top stands for any JD outside that list
    jd_rank = np.full_like(attribution_arr, n_top, dtype=int)
    for rank, jd in enumerate(top_jds):
        jd_rank[attribution_arr == jd] = rank
    jd_rank = np.ma.array(jd_rank, mask=np.isnan(attribution_arr))

    # one discrete color per rank, black for "another JD", white for masked pixels
    tab10 = matplotlib.colormaps['tab10']
    cmap = ListedColormap([tab10(k) for k in range(n_top)] + ['black'])
    cmap.set_bad('white')
    norm = BoundaryNorm(np.arange(-0.5, n_top + 1.5, 1), cmap.N)

    unwrapped = np.where(lsts > lsts[-1], lsts - 2 * np.pi, lsts)
    lsts_in_hours = unwrapped * 12 / np.pi
    extent = [freqs[0]/1e6, freqs[-1]/1e6, lsts_in_hours[-1], lsts_in_hours[0]]

    fig, axes = plt.subplots(1, 2, figsize=(14, 10), sharex=True, sharey=True,
                             gridspec_kw={'wspace': 0}, dpi=200)
    for ax, pol in zip(axes, ['ee', 'nn']):
        im = ax.imshow(jd_rank[:, :, hd.pols.index(pol)], aspect='auto', interpolation='none',
                       cmap=cmap, norm=norm, extent=extent)
        ax.text(0.015, 0.985, pol, transform=ax.transAxes, ha='left', va='top',
                fontsize=12, bbox=dict(facecolor='white', alpha=0.5, boxstyle='round'))
        ax.set_xlabel('Frequency (MHz)')

    axes[0].set_ylabel('LST (Hours)')
    for ax in axes:
        # relabel the y ticks modulo 24 hours; the UserWarning raised by set_yticklabels is silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            ax.set_yticklabels([f'{(int(val) if np.isclose(val, int(val)) else val) % 24:n}' for val in ax.get_yticks()])
    plt.tight_layout()

    tick_names = [str(int(jd)) for jd in top_jds] + ['Another JD']
    cbar = plt.colorbar(im, ax=axes, location='top', pad=.02, aspect=50, ticks=range(n_top + 1))
    cbar.set_ticklabels(tick_names)
    cbar.ax.tick_params(labelsize=10, rotation=45)
    cbar.set_label(f'{antpair}: Night with Highest Modified z-Score', fontsize=16)
def plot_z2_histograms(hist_counts, hist_bins, configurator, top_jds, jd_to_color, antpair, inpainted=False):
    '''Plots per-night z^2 histograms split by pol and band from precomputed bin counts.

    Arguments:
        hist_counts: dict mapping JD to an array of counts indexed as (inpainted, pol, band, z^2 bin)
        hist_bins: 1D array of uniformly spaced z^2 bin edges
        configurator: object whose .nights lists the nights; it sets the plotting order
        top_jds: JDs drawn in their own color; all other JDs are drawn in thin black lines
        jd_to_color: dict mapping each of top_jds to a color
        antpair: antenna pair used in the figure title
        inpainted: if True, plot the histograms of inpainted samples, otherwise of not-inpainted ones
    '''
    iidx = 1 if inpainted else 0
    bw = hist_bins[1] - hist_bins[0]
    fig, axes = plt.subplots(2, 2, figsize=(14, 10), sharex=True, sharey=True, dpi=200)

    # top JDs first, then everything else, both in configurator order
    nights = [int(n) for n in configurator.nights]
    ordered_jds = [jd for jd in nights if jd in top_jds] + [jd for jd in nights if jd not in top_jds]

    # Track, per panel, whether an 'Other JDs' line has actually been drawn. Tying the label
    # to a fixed position in ordered_jds loses it whenever that night is absent from
    # hist_counts or has no counts in the panel that carries the legend.
    other_labeled = np.zeros(axes.shape, dtype=bool)
    for jd in ordered_jds:
        if jd not in hist_counts:
            continue
        is_top = jd in top_jds
        for pidx in range(2):
            for bidx in range(2):
                counts = hist_counts[jd][iidx, pidx, bidx]
                total = counts.sum()
                if total == 0:
                    continue
                if is_top:
                    kwargs = dict(label=jd, linewidth=1, color=jd_to_color[jd])
                else:
                    kwargs = dict(color='k', alpha=.5, linewidth=.5, zorder=1000,
                                  label=(None if other_labeled[pidx, bidx] else 'Other JDs'))
                    other_labeled[pidx, bidx] = True
                # normalize the counts to a probability density
                density = counts / (total * bw)
                axes[pidx, bidx].stairs(density, hist_bins, **kwargs)

    pols = ['ee', 'nn']
    bands = ['Below FM', 'Above FM']
    for pidx in range(2):
        for bidx in range(2):
            ax = axes[pidx, bidx]
            ax.set_yscale('log')
            ax.set_xlim(hist_bins[0], hist_bins[-1])
            if bidx == 0:
                ax.set_ylabel('Density')
            ax.text(0.015, 0.975, f'{pols[pidx]}, {bands[bidx]}', transform=ax.transAxes,
                    ha='left', va='top', fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.5, boxstyle='round'))

    # only build a legend when something labeled was drawn; an empty legend just emits a warning
    legend_ax = axes[0, -1]
    if legend_ax.get_legend_handles_labels()[0]:
        legend_ax.legend(fontsize=8)
    axes[-1, 0].set_xlabel('$z^2$-score')
    axes[-1, 1].set_xlabel('$z^2$-score')
    title_suffix = 'Inpainted' if inpainted else 'Not Inpainted'
    fig.suptitle(f'{antpair}: {title_suffix}', fontsize=16)
    fig.tight_layout()
def plot_time_avg_z2(time_avg_z2s, top_jds, crosses, antpair):
    '''Plots time-averaged z^2-score vs frequency for each night.

    Arguments:
        time_avg_z2s: dict mapping JD to an array indexed as (frequency, pol)
        top_jds: JDs drawn in their own color (looked up in the module-level jd_to_color)
        crosses: object providing crosses.hd.freqs (Hz) and crosses.hd.pols
        antpair: antenna pair used in the figure title
    '''
    freqs_mhz = crosses.hd.freqs / 1e6
    other_jds = [d for d in time_avg_z2s if d not in top_jds]
    fig, axes = plt.subplots(2, 1, figsize=(14, 10), sharex=True, sharey=True, dpi=200)
    for ax, pol in zip(axes, ['ee', 'nn']):
        pidx = crosses.hd.pols.index(pol)
        for jd in top_jds:
            if jd not in time_avg_z2s:
                continue
            ax.plot(freqs_mhz, time_avg_z2s[jd][:, pidx], label=int(jd), color=jd_to_color[jd])
        # remaining nights share one style; only the first gets a legend entry
        for n, jd in enumerate(other_jds):
            ax.plot(freqs_mhz, time_avg_z2s[jd][:, pidx], color='k', alpha=.5, lw=.5,
                    zorder=1000, label=('Other JDs' if n == 0 else None))
        ax.legend(ncol=2, title=f'{pol}-polarized')
        ax.set_ylabel('Time-Averaged $z^2$-Score')
        ax.set_xlim(freqs_mhz[0], freqs_mhz[-1])
    axes[-1].set_xlabel('Frequency (MHz)')
    fig.suptitle(f'{antpair}: Time-Averaged $z^2$-Score vs Frequency', fontsize=16)
    plt.tight_layout()
def plot_freq_avg_z2(freq_avg_z2s, lsts_per_jd, top_jds, crosses, antpair):
    '''Plots frequency-averaged z^2-score vs LST for each night.

    Arguments:
        freq_avg_z2s: dict mapping JD to an array indexed as (time, pol)
        lsts_per_jd: dict mapping JD to the LSTs (radians) of the rows in freq_avg_z2s[jd]
        top_jds: JDs drawn in their own color (looked up in the module-level jd_to_color)
        crosses: object providing crosses.hd.pols
        antpair: antenna pair used in the figure title
    '''
    other_jds = [d for d in freq_avg_z2s if d not in top_jds]
    fig, axes = plt.subplots(2, 1, figsize=(14, 10), sharex=True, sharey=True, dpi=200)
    for ax, pol in zip(axes, ['ee', 'nn']):
        pidx = crosses.hd.pols.index(pol)
        for jd in top_jds:
            if jd not in freq_avg_z2s:
                continue
            ax.plot(lsts_per_jd[jd] * 12 / np.pi, freq_avg_z2s[jd][:, pidx], label=int(jd), color=jd_to_color[jd])
        # remaining nights share one style; only the first gets a legend entry
        for n, jd in enumerate(other_jds):
            ax.plot(lsts_per_jd[jd] * 12 / np.pi, freq_avg_z2s[jd][:, pidx], color='k', alpha=.5, lw=.5,
                    zorder=1000, label=('Other JDs' if n == 0 else None))
        ax.legend(ncol=2, title=f'{pol}-polarized')
        ax.set_ylabel('Frequency-Averaged $z^2$-Score')
        # relabel the x ticks modulo 24 hours; the UserWarning raised by set_xticklabels is silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            ax.set_xticklabels([f'{(int(val) if np.isclose(val, int(val)) else val) % 24:n}' for val in ax.get_xticks()])
    axes[-1].set_xlabel('LST (Hours)')
    fig.suptitle(f'{antpair}: Frequency-Averaged $z^2$-Score vs LST', fontsize=16)
    plt.tight_layout()
Figure 1: East-Polarized LST-Stacked Amplitude, Phase, and Nsamples after Re-Inpainting¶
# waterfalls of the LST-stacked ee-polarized data
pidx = hd.pols.index('ee')
plot_waterfall(data=lst_avg_data[:, :, pidx],
               flags=lst_avg_flags[:, :, pidx],
               nsamples=lst_avg_nsamples[:, :, pidx],
               freqs=hd.freqs,
               lsts=crosses.bin_lst,
               bl_label=crosses.hd.antpairs[0] + ('ee',))
This waterfall is entirely flagged. Nothing to plot.
Figure 2: North-Polarized LST-Stacked Amplitude, Phase, and Nsamples after Re-Inpainting¶
# waterfalls of the LST-stacked nn-polarized data
pidx = hd.pols.index('nn')
plot_waterfall(data=lst_avg_data[:, :, pidx],
               flags=lst_avg_flags[:, :, pidx],
               nsamples=lst_avg_nsamples[:, :, pidx],
               freqs=hd.freqs,
               lsts=crosses.bin_lst,
               bl_label=crosses.hd.antpairs[0] + ('nn',))
This waterfall is entirely flagged. Nothing to plot.
Figure 3: Modified z-Score Across Nights, Before and After Re-Inpainting¶
# maximum modified z-scores before vs. after re-inpainting
compare_max_mod_zs(max_mod_z_before=max_mod_z_before,
                   max_mod_z_after=max_mod_z_after,
                   hd=crosses.hd,
                   freqs=crosses.hd.freqs,
                   lsts=crosses.bin_lst,
                   antpair=crosses.hd.antpairs[0])
Figure 4: Night with Highest Modified z-Score After Re-Inpainting¶
# which night carries the highest modified z-score in each pixel
plot_attribution(attribution_arr=attribution_arr,
                 top_jds=top_jds,
                 hd=crosses.hd,
                 freqs=crosses.hd.freqs,
                 lsts=crosses.bin_lst,
                 antpair=crosses.hd.antpairs[0])
Figure 5: Per-Night $z^2$-Score Histograms for Not-Inpainted Data¶
# z² histograms of the samples that were not inpainted
plot_z2_histograms(hist_counts=hist_counts,
                   hist_bins=hist_bins,
                   configurator=configurator,
                   top_jds=top_jds,
                   jd_to_color=jd_to_color,
                   antpair=crosses.hd.antpairs[0],
                   inpainted=False)
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Figure 6: Per-Night $z^2$-Score Histograms for Inpainted Data¶
# z² histograms of the samples that were inpainted
plot_z2_histograms(hist_counts=hist_counts,
                   hist_bins=hist_bins,
                   configurator=configurator,
                   top_jds=top_jds,
                   jd_to_color=jd_to_color,
                   antpair=crosses.hd.antpairs[0],
                   inpainted=True)
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Figure 7: Time-Averaged $z^2$-Score vs Frequency¶
# per-night time-averaged z² as a function of frequency
plot_time_avg_z2(time_avg_z2s=time_avg_z2s,
                 top_jds=top_jds,
                 crosses=crosses,
                 antpair=crosses.hd.antpairs[0])
Figure 8: Frequency-Averaged $z^2$-Score vs LST¶
# per-night frequency-averaged z² as a function of LST
plot_freq_avg_z2(freq_avg_z2s=freq_avg_z2s,
                 lsts_per_jd=lsts_per_jd,
                 top_jds=top_jds,
                 crosses=crosses,
                 antpair=crosses.hd.antpairs[0])
gc_and_malloc_trim()
print_peak_rss("after visualization")
[mem] after visualization: current RSS = 6.98 GB, peak RSS = 8.68 GB, elapsed = 3.48 min
FR=0 Filter¶
# Build 2D-DPSS-smoothed autocorrelations, so they can be used as weights without
# introducing spectral structure. Anything not covered by a fit stays NaN.
if FR0_FILTER and not (IS_AUTO and not FR0_FILT_AUTOS):
    filtered_autos = np.full_like(lst_avg_auto_data, np.nan)
    for pol in ['ee', 'nn']:
        pidx = hd.pols.index(pol)
        # split the unflagged region at the middle of the FM band
        tslices, fslices = flag_utils.get_minimal_slices(lst_avg_flags[:, :, pidx], freqs=hd.freqs,
                                                         freq_cuts=[(FM_HIGH_FREQ + FM_LOW_FREQ) * .5e6])
        for tslice, fslice in zip(tslices, fslices):
            if tslice is None or fslice is None:
                continue
            cache_fr_center_and_hw(hd, hd.antpairs[0], tslice, fslice)
            autos_here = lst_avg_auto_data[tslice, fslice, pidx]
            nsamples_here = lst_avg_auto_nsamples[tslice, fslice, pidx]
            # non-finite weights (e.g. from zero-valued autos) are zeroed out
            weights_here = autos_here**-2 * nsamples_here
            weights_here = np.where(np.isfinite(weights_here), weights_here, 0)

            fr_center, fr_hw = FR_CENTER_AND_HW_CACHE[(hd.antpairs[0], tslice, fslice)]
            time_filters, _ = dpss_operator((bin_times[tslice] - bin_times[tslice][0]) * 3600 * 24,
                                            [fr_center], [fr_hw], eigenval_cutoff=[EIGENVAL_CUTOFF])
            freq_filters, _ = dpss_operator(hd.freqs[fslice], [0.0], [AUTO_INPAINT_DELAY / 1e9],
                                            eigenval_cutoff=[EIGENVAL_CUTOFF])
            fit, meta = sparse_linear_fit_2D(data=autos_here,
                                             weights=weights_here,
                                             precondition_solver=True,
                                             axis_1_basis=time_filters,
                                             axis_2_basis=freq_filters,
                                             iter_lim=CG_ITER_LIM,
                                             atol=CG_TOL,
                                             btol=CG_TOL)
            dpss_fit = time_filters.dot(fit).dot(freq_filters.T)
            filtered_autos[tslice, fslice, pidx] = np.abs(dpss_fit)
# Perform FR=0 filter on crosses on a per-night basis
def _fr0_filter_one_jd(jd):
    '''Perform FR=0 filtering for a single JD, updating crosses.data in place.

    From every LST bin, the first row whose floor(time) equals jd is gathered into one
    (LST bin, freq, pol) array. For each polarization and each unflagged block, a DPSS model
    centered at zero fringe rate is fit along the time axis and subtracted. The filtered rows
    are then written back into crosses.data, with flagged samples set to NaN.

    Arguments:
        jd: integer Julian date of the night to filter
    '''
    data_here, flags_here, nsamples_here, tindices = [], [], [], []
    for d, f, n, tib in zip(crosses.data, crosses.flags, crosses.nsamples, crosses.times_in_bins):
        # find the indices of the times that are in this JD
        matches = np.argwhere(np.floor(tib).astype(int) == jd)
        if len(matches) == 0:
            # this JD is absent from this bin: stand in a fully flagged row
            data_here.append(np.full(d.shape[1:], np.nan))
            flags_here.append(np.full(f.shape[1:], True))
            nsamples_here.append(np.zeros(n.shape[1:], dtype=n.dtype))
            tindices.append(None)
            continue
        row = matches[0][0]
        data_here.append(np.where(f[row], np.nan, d[row]))
        flags_here.append(f[row])
        nsamples_here.append(n[row])
        tindices.append(row)
    data_here, flags_here, nsamples_here = np.array(data_here), np.array(flags_here), np.array(nsamples_here)

    for pol in crosses.hd.pols:
        pidx = crosses.hd.pols.index(pol)
        # get indices for indexing into autocorrelations for weights
        feed_1, feed_2 = utils.split_pol(pol)
        pidx1 = crosses.hd.pols.index(utils.join_pol(feed_1, feed_1))
        pidx2 = crosses.hd.pols.index(utils.join_pol(feed_2, feed_2))
        auto_1 = filtered_autos[:, :, pidx1]
        auto_2 = filtered_autos[:, :, pidx2]
        # zero weight wherever the data are flagged or either smoothed auto is not finite
        unusable = flags_here[:, :, pidx] | ~np.isfinite(auto_1) | ~np.isfinite(auto_2)
        weights_here = np.where(unusable, 0, nsamples_here[:, :, pidx] * auto_1**-1 * auto_2**-1)

        tslices, fslices = flag_utils.get_minimal_slices(flags_here[:, :, pidx], freqs=hd.freqs,
                                                         freq_cuts=[(FM_HIGH_FREQ + FM_LOW_FREQ) * .5e6])
        for tslice, fslice in zip(tslices, fslices):
            if tslice is None or fslice is None:
                continue
            block_wgts = weights_here[tslice, fslice]
            # zero-weight samples are set to 0 so the NaNs of flagged data never reach the filter
            block_data = np.where(block_wgts == 0, 0, data_here[tslice, fslice, pidx])
            d_mdl, _, info = fourier_filter(bin_times[tslice] * 24 * 60 * 60,
                                            block_data,
                                            wgts=block_wgts,
                                            filter_centers=[0],
                                            filter_half_widths=[FR0_HALFWIDTH / 1000],
                                            mode='dpss_solve',
                                            eigenval_cutoff=[EIGENVAL_CUTOFF],
                                            suppression_factors=[EIGENVAL_CUTOFF],
                                            max_contiguous_edge_flags=len(bin_times[tslice]),
                                            filter_dims=0)
            data_here[tslice, fslice, pidx] -= d_mdl

    # update data in crosses
    for d, d_filt, tidx in zip(crosses.data, data_here, tindices):
        if tidx is None:
            continue
        d[tidx, :, :] = d_filt
# Filter every night on the thread pool; each task edits crosses.data in place.
if FR0_FILTER and not (IS_AUTO and not FR0_FILT_AUTOS):
    jds = list(map(int, crosses.configurator.nights))
    with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
        # list() drains the iterator so any exception raised in a worker surfaces here
        list(executor.map(_fr0_filter_one_jd, jds))
invalid value encountered in reciprocal
# Re-average over nights now that crosses.data holds the FR=0-filtered data; only the
# averaged data are kept, and the second and third returns are discarded.
if FR0_FILTER:
    (lst_avg_fr0_filt_data,
     _,
     _) = crosses.average_over_nights()
gc_and_malloc_trim()
print_peak_rss("after FR=0 filter and averaging over nights")
[mem] after FR=0 filter and averaging over nights: current RSS = 8.22 GB, peak RSS = 37.94 GB, elapsed = 6.25 min
Save Results¶
# Assemble the provenance text that is prepended to the history of every output file.
add_to_history = 'This file was produced by single_baseline_lst_stack_and_reinpaint.ipynb notebook.\n'
# read the environment export through a context manager so the pipe to the subprocess is closed
with os.popen('conda env export') as _conda_pipe:
    _conda_env = _conda_pipe.read()
add_to_history += 'The following conda environment was used:\n' + '=' * 65 + '\n' + _conda_env + '=' * 65 + '\n'
add_to_history += f'The toml file {toml_file} was used:\n' + '=' * 65 + '\n' + toml.dumps(toml_options) + '=' * 65 + '\n'
add_to_history += 'The following files were stacked:\n' + '=' * 65 + '\n' + '\n'.join(configurator.bl_to_file_map[baseline_string]) + '\n' + '=' * 65 + '\n'
add_to_history += f'Averaged auto cache key: {cache_key}\n'
def _write_lst_avg_data(outfile, data, flags, nsamples):
    '''Writes LST-averaged data, flags, and nsamples to a uvh5 file.

    A new UVData object is built from crosses.hd metadata with bin_times as its times.
    add_to_history is prepended to its history, and any existing file is overwritten.

    Arguments:
        outfile: path of the uvh5 file to write
        data: array to store as the data_array
        flags: array to store as the flag_array
        nsamples: array to store as the nsample_array
    '''
    # polarization strings -> integer polarization numbers
    pol_nums = [
        utils.polstr2num(p, x_orientation=crosses.hd.telescope.get_x_orientation_from_feeds())
        for p in crosses.hd.pols
    ]
    new_kwargs = dict(
        freq_array=crosses.hd.freq_array,
        polarization_array=pol_nums,
        times=bin_times,
        telescope=crosses.hd.telescope,
        antpairs=crosses.hd.antpairs,
        vis_units=crosses.hd.vis_units,
        empty=True,
    )
    uvd = UVData.new(**new_kwargs)

    uvd.data_array = data
    uvd.flag_array = flags
    uvd.pol_convention = pol_convention
    uvd.nsample_array = nsamples
    uvd.history = add_to_history + uvd.history
    uvd.write_uvh5(outfile, clobber=True)
# write the lst-stacked and averaged data to a uvh5 file
_write_lst_avg_data(outfile=OUTFILE, data=lst_avg_data, flags=lst_avg_flags, nsamples=lst_avg_nsamples)
if FR0_FILTER:
    # the FR=0-filtered product reuses the flags and nsamples of the unfiltered one
    _write_lst_avg_data(outfile=FR0_FILT_OUTFILE, data=lst_avg_fr0_filt_data,
                        flags=lst_avg_flags, nsamples=lst_avg_nsamples)
gc_and_malloc_trim()
print_peak_rss("after write")
File exists; clobbering
File exists; clobbering
[mem] after write: current RSS = 8.22 GB, peak RSS = 37.94 GB, elapsed = 6.45 min
Metadata¶
import importlib

# Print the version of each package this notebook depends on. The version is read off the
# imported module object rather than by exec-ing a generated import statement.
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')
hera_cal: 3.7.8.dev38+g2607f380d hera_qm: 2.2.1.dev12+g95ecc30f0 hera_filters: 0.1.9.dev3+ged4deb46a
hera_notebook_templates: 0.0.1.dev1496+g4cc94af0b pyuvdata: 3.2.7.dev8+g5fca0c330 numpy: 2.3.5
# total wall-clock time since tstart, in minutes
elapsed_minutes = (time.time() - tstart) / 60
print(f'Finished execution in {elapsed_minutes:.2f} minutes.')
print_peak_rss("final")
Finished execution in 6.47 minutes. [mem] final: current RSS = 8.24 GB, peak RSS = 37.94 GB, elapsed = 6.47 min