Full Day RFI Flagging¶
by Josh Dillon, last updated August 1, 2025
This notebook is designed to figure out a single full-day RFI mask using the best autocorrelations, taking individual file_calibration notebook results as a prior but then potentially undoing flags.
Here's a set of links to skip to particular figures and tables:
• Figure 1: Show All DPSS Residual z-Scores¶
• Figure 2: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Initial Flags¶
• Figure 3: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Expanded Flags¶
• Figure 4: z-Score of DPSS-Filtered, Averaged Good Autocorrelation and Final, Re-Computed Flags¶
• Figure 5: Summary of Flags Before and After Recomputing Them¶
In [1]:
import time
# Record the wall-clock time at which the notebook started running.
tstart = time.time()
In [2]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import pandas as pd
import glob
import os
import matplotlib.pyplot as plt
import matplotlib
import copy
import warnings
import textwrap
from pyuvdata import UVFlag, UVData, UVCal
from hera_cal import io, utils, abscal
from hera_cal.smooth_cal import CalibrationSmoother, dpss_filters, solve_2D_DPSS
from hera_qm import ant_class, xrfi, metrics_io
from hera_filters import dspec
from IPython.display import display, HTML
%matplotlib inline
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore')  # get rid of red warnings
%config InlineBackend.figure_format = 'retina'
Parse inputs¶
In [3]:
# get filenames
SUM_FILE = os.environ.get("SUM_FILE", None)
# SUM_FILE = '/lustre/aoc/projects/hera/h6c-analysis/IDR2/2459866/zen.2459866.25282.sum.uvh5'  # If sum_file is not defined in the environment variables, define it here.
if SUM_FILE is None:
    # Fail early with an actionable message instead of an AttributeError on None.split below.
    raise ValueError('SUM_FILE is not set. Define it as an environment variable or '
                     'uncomment the hard-coded SUM_FILE assignment in this cell.')
SUM_SUFFIX = os.environ.get("SUM_SUFFIX", 'sum.uvh5')
SUM_AUTOS_SUFFIX = os.environ.get("SUM_AUTOS_SUFFIX", 'sum.autos.uvh5')
DIFF_AUTOS_SUFFIX = os.environ.get("DIFF_AUTOS_SUFFIX", 'diff.autos.uvh5')
CAL_SUFFIX = os.environ.get("CAL_SUFFIX", 'sum.omni.calfits')
ANT_CLASS_SUFFIX = os.environ.get("ANT_CLASS_SUFFIX", 'sum.ant_class.csv')
APRIORI_YAML_PATH = os.environ.get("APRIORI_YAML_PATH", None)
OUT_FLAG_SUFFIX = os.environ.get("OUT_FLAG_SUFFIX", 'sum.flag_waterfall.h5')

# Drop the last three dot-separated fields of SUM_FILE (e.g. zen.2459866.25282.sum.uvh5 -> zen.2459866)
# and put a wildcard where the per-file field was, so that each glob matches all files of that kind.
day_prefix = '.'.join(SUM_FILE.split('.')[:-3])
sum_glob = day_prefix + '.*.' + SUM_SUFFIX
# Build each glob from the shared prefix rather than with str.replace on sum_glob, which would
# also rewrite any earlier occurrence of SUM_SUFFIX in the path.
auto_sums_glob = day_prefix + '.*.' + SUM_AUTOS_SUFFIX
auto_diffs_glob = day_prefix + '.*.' + DIFF_AUTOS_SUFFIX
cal_files_glob = day_prefix + '.*.' + CAL_SUFFIX
ant_class_csvs_glob = day_prefix + '.*.' + ANT_CLASS_SUFFIX
In [4]:
def env_float(name, default, *legacy_names):
    '''Read a float-valued setting from the environment.

    Arguments:
        name: environment variable to look up first.
        default: value used when neither name nor any of legacy_names is set.
        legacy_names: older (e.g. misspelled) variable names to fall back on, in order.

    Returns:
        float of the first value found, or float(default) if none are set.
    '''
    for key in (name,) + legacy_names:
        if key in os.environ:
            return float(os.environ[key])
    return float(default)


# A priori flag settings
FM_LOW_FREQ = env_float("FM_LOW_FREQ", 87.5)  # in MHz
FM_HIGH_FREQ = env_float("FM_HIGH_FREQ", 108.0)  # in MHz
FM_freq_range = [FM_LOW_FREQ * 1e6, FM_HIGH_FREQ * 1e6]  # in Hz
MAX_SOLAR_ALT = env_float("MAX_SOLAR_ALT", 0.0)  # in degrees
PER_POL_FILE_FLAG_THRESH = env_float("PER_POL_FILE_FLAG_THRESH", .75)
# DPSS settings
FREQ_FILTER_SCALE = env_float("FREQ_FILTER_SCALE", 5.0)  # in MHz
TIME_FILTER_SCALE = env_float("TIME_FILTER_SCALE", 450.0)  # in s
EIGENVAL_CUTOFF = env_float("EIGENVAL_CUTOFF", 1e-12)
# Outlier flagging settings
MIN_FRAC_OF_AUTOS = env_float("MIN_FRAC_OF_AUTOS", .25)
# The environment variables for the next setting and for REPEAT_FLAG_Z_THRESH used to be read under
# misspelled names; the correct names now take precedence and the old ones remain as fallbacks.
MAX_AUTO_L2 = env_float("MAX_AUTO_L2", 1.2, "MAX_AUTRO_L2")
Z_THRESH = env_float("Z_THRESH", 5.0)
WS_Z_THRESH = env_float("WS_Z_THRESH", 4.0)
AVG_Z_THRESH = env_float("AVG_Z_THRESH", 1.5)
REPEAT_FLAG_Z_THRESH = env_float("REPEAT_FLAG_Z_THRESH", 0.0, "REPEAT_FLAG_Z_THESH")
MAX_FREQ_FLAG_FRAC = env_float("MAX_FREQ_FLAG_FRAC", .25)
MAX_TIME_FLAG_FRAC = env_float("MAX_TIME_FLAG_FRAC", .1)

# Echo the settings actually in use, looking each one up by name in this namespace.
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'MAX_SOLAR_ALT', 'PER_POL_FILE_FLAG_THRESH',
                'FREQ_FILTER_SCALE', 'TIME_FILTER_SCALE', 'EIGENVAL_CUTOFF', 'MIN_FRAC_OF_AUTOS',
                'MAX_AUTO_L2', 'Z_THRESH', 'WS_Z_THRESH', 'AVG_Z_THRESH', 'REPEAT_FLAG_Z_THRESH',
                'MAX_FREQ_FLAG_FRAC', 'MAX_TIME_FLAG_FRAC']:
    print(f'{setting} = {globals()[setting]}')
FM_LOW_FREQ = 87.5 FM_HIGH_FREQ = 108.0 MAX_SOLAR_ALT = 0.0 PER_POL_FILE_FLAG_THRESH = 0.75 FREQ_FILTER_SCALE = 5.0 TIME_FILTER_SCALE = 450.0 EIGENVAL_CUTOFF = 1e-12 MIN_FRAC_OF_AUTOS = 0.25 MAX_AUTO_L2 = 1.2 Z_THRESH = 5.0 WS_Z_THRESH = 4.0 AVG_Z_THRESH = 1.5 REPEAT_FLAG_Z_THRESH = 0.0 MAX_FREQ_FLAG_FRAC = 0.25 MAX_TIME_FLAG_FRAC = 0.1
Load Data¶
In [5]:
def find_files(pattern, suffix):
    '''Return the sorted list of files matching a glob pattern and report how many were found.

    Arguments:
        pattern: glob pattern to search with.
        suffix: file suffix, used only in the printed/raised messages.

    Returns:
        Sorted list of matching paths (never empty).

    Raises:
        FileNotFoundError: if nothing matches the pattern.
    '''
    found = sorted(glob.glob(pattern))
    if len(found) == 0:
        raise FileNotFoundError(f'Found no *.{suffix} files matching {pattern}.')
    print(f'Found {len(found)} *.{suffix} files starting with {found[0]}.')
    return found


auto_sums = find_files(auto_sums_glob, SUM_AUTOS_SUFFIX)
auto_diffs = find_files(auto_diffs_glob, DIFF_AUTOS_SUFFIX)
cal_files = find_files(cal_files_glob, CAL_SUFFIX)
ant_class_csvs = find_files(ant_class_csvs_glob, ANT_CLASS_SUFFIX)

# These lists are paired up element-by-element (with zip) further down, which silently
# truncates to the shortest one, so make any mismatch visible here.
file_counts = {SUM_AUTOS_SUFFIX: len(auto_sums), DIFF_AUTOS_SUFFIX: len(auto_diffs),
               CAL_SUFFIX: len(cal_files), ANT_CLASS_SUFFIX: len(ant_class_csvs)}
if len(set(file_counts.values())) > 1:
    warnings.warn(f'Unequal numbers of input files found, so files may be mismatched: {file_counts}')
Found 1850 *.sum.autos.uvh5 files starting with /lustre/aoc/projects/hera/h6c-analysis/IDR3/2459883/zen.2459883.25268.sum.autos.uvh5. Found 1850 *.diff.autos.uvh5 files starting with /lustre/aoc/projects/hera/h6c-analysis/IDR3/2459883/zen.2459883.25268.diff.autos.uvh5. Found 1850 *.sum.omni.calfits files starting with /lustre/aoc/projects/hera/h6c-analysis/IDR3/2459883/zen.2459883.25268.sum.omni.calfits. Found 1850 *.sum.ant_class.csv files starting with /lustre/aoc/projects/hera/h6c-analysis/IDR3/2459883/zen.2459883.25268.sum.ant_class.csv.
In [6]:
# Read every per-file antenna classification table, dropping rows that are entirely empty
tables = []
for csv_path in ant_class_csvs:
    tables.append(pd.read_csv(csv_path).dropna(axis=0, how='all'))

# After the first column, the columns of the first table alternate between two kinds:
# table_cols takes positions 1, 3, 5, ... and class_cols takes positions 2, 4, 6, ...
first_table_columns = tables[0].columns
table_cols = first_table_columns[1::2]
class_cols = first_table_columns[2::2]
In [7]:
# set up for figuring out candidate antennas
# antenna-polarization strings (antenna number followed by a one-character pol) from the first table
ap_strs = np.array(tables[0]['Antenna'])
# True wherever a classification column reads 'bad'; axes are (file, antenna-pol, class column)
ant_flags = np.array([t[class_cols] for t in tables]) == 'bad'
# True where the tabulated solar altitude is below MAX_SOLAR_ALT; axes are (file, antenna-pol)
sun_low_enough = np.array([t['Solar Alt'] < MAX_SOLAR_ALT for t in tables])
# unique antenna numbers, i.e. each string with its final character stripped
ants = sorted(set(int(a[:-1]) for a in ap_strs))
# get relevant indices (exclude antennas only flagged for Even/Odd Zeros or Redcal chi^2 or Bad X-Engine Diffs)
e_pols = [i for i, ap_str in enumerate(ap_strs) if 'e' in ap_str]
n_pols = [i for i, ap_str in enumerate(ap_strs) if 'n' in ap_str]
cols_to_use = [cc for cc, colname in enumerate(class_cols) if colname not in 
               ['Antenna Class', 'Even/Odd Zeros Class','Redcal chi^2 Class', 'Bad Diff X-Engines Class']]
# perform an "any" over flagging rationales, excluding times where the sun is too high
# NOTE: despite its name, passes_checks_grid is True where an antenna-pol is classified 'bad'
# in at least one of cols_to_use (and the sun is low enough), i.e. where it FAILS a check.
passes_checks_grid = np.any(ant_flags[:, :, cols_to_use] & sun_low_enough[:, :, None], axis=2)
# also exclude nearly fully-flagged files: those where, for either polarization, the fraction
# of 'bad' antenna-pols exceeds PER_POL_FILE_FLAG_THRESH
files_to_flag = np.mean(passes_checks_grid[:, e_pols], axis=1) > PER_POL_FILE_FLAG_THRESH
files_to_flag |= np.mean(passes_checks_grid[:, n_pols], axis=1) > PER_POL_FILE_FLAG_THRESH
print(f'Found {int(np.sum(files_to_flag))} files to fully flag based on one pol exceeding {PER_POL_FILE_FLAG_THRESH:.2%}')
# an antenna-pol is a candidate only if it is never 'bad' in any file that is not being fully flagged
is_candidate_auto = ~np.any(passes_checks_grid[~files_to_flag, :], axis=0)
# get set of candidate autocorrelation keys
candidate_autos = set()
for ap_str in ap_strs[is_candidate_auto]:
    # split the string into (antenna number, pol) and pair it with itself to make an auto key
    ap = int(ap_str[:-1]), utils.comply_pol(ap_str[-1])
    candidate_autos.add(utils.join_bl(ap, ap))
print(f'{len(candidate_autos)} candidate autocorrelations identified for RFI flagging.')
Found 1 files to fully flag based on one pol exceeding 75.00% 76 candidate autocorrelations identified for RFI flagging.
In [8]:
# Load sum and diff autos, checking to see whether any of them show packet loss
good_data = {}  # maps sum autos file -> the 'data' entry returned by io.read_hera_hdf5
info_dicts = {}  # maps sum autos file -> the 'info' entry returned by io.read_hera_hdf5
new_fully_flagged_files = 0  # number of files newly marked in files_to_flag by this loop
for i, (sf, df, f2f) in enumerate(zip(auto_sums, auto_diffs, files_to_flag)):
    # files already marked for full flagging are never read
    if f2f:
        continue
    # candidate_autos can shrink as the loop proceeds, so later files are read with fewer baselines
    rv = io.read_hera_hdf5(sf, bls=candidate_autos)
    good_data[sf] = rv['data']
    info_dicts[sf] = rv['info']
    diff = io.read_hera_hdf5(df, bls=candidate_autos)['data']
    zeros_class = ant_class.even_odd_zeros_checker(good_data[sf], diff)
    
    # if this file is fully flagged, don't let that affect the candidates
    if len(zeros_class.bad_ants) == len(zeros_class.ants):
        new_fully_flagged_files += 1
        files_to_flag[i] = True  # modifies files_to_flag in place, so later cells see this file as flagged
        continue    
    
    # otherwise, any antenna classified as bad in this file stops being a candidate for the whole day
    for ant in zeros_class.bad_ants:
        candidate_autos.remove(utils.join_bl(ant, ant))
        print(f'Removing {utils.join_bl(ant, ant)} on {sf}, {len(candidate_autos)} remain.')
    
if new_fully_flagged_files > 0:
    print(f'{new_fully_flagged_files} additional files were flagged because they were 100% flagged for Even/Odd zeros.')
print(f'{len(candidate_autos)} candidate autocorrelations remain after looking for packet loss effects in autos.')
76 candidate autocorrelations remain after looking for packet loss effects in autos.
In [9]:
# load calibration solutions
# cs provides the time_grid, freqs, cals, time_indices, and per-antenna flag_grids used in the cells below
cs = CalibrationSmoother(cal_files, load_cspa=False, load_chisq=False, pick_refant=False)
In [10]:
# Build a per-integration boolean mask of a priori flagged times.
# apriori_flags is only defined when an a priori YAML path was supplied.
if APRIORI_YAML_PATH is not None:
    print(f'Loading a priori flagged times from {APRIORI_YAML_PATH}')
    apriori_flags = np.zeros(len(cs.time_grid), dtype=bool)
    apriori_flagged_ints = metrics_io.read_a_priori_int_flags(APRIORI_YAML_PATH, times=cs.time_grid)
    apriori_flags[apriori_flagged_ints.astype(int)] = True
Loading a priori flagged times from /lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/analysis/apriori_flags/2459883_apriori_flags.yaml
In [11]:
# Fully flag, for every antenna, all integrations belonging to files marked in files_to_flag.
# The sorted cal-file keys are indexed with files_to_flag, so the two are taken to be in the same order.
sorted_cal_files = np.array(sorted(cs.time_indices.keys()))
for flagged_cal_file in sorted_cal_files[files_to_flag]:
    for flag_grid in cs.flag_grids.values():
        flag_grid[cs.time_indices[flagged_cal_file], :] = True
Figure out a subset of most-stable antennas to filter and flag on¶
In [12]:
# A pixel starts out flagged only if it is flagged in every antenna's flag grid.
initial_cal_flags = np.all(list(cs.flag_grids.values()), axis=0)
In [13]:
def average_autos(per_file_autos, files_to_flag, bls_to_use, auto_sums, cs):
    '''Average autocorrelations over baselines and place them on the time grid of cs.

    Arguments:
        per_file_autos: dictionary mapping each file in auto_sums to a mapping from baseline to data.
        files_to_flag: iterable of booleans, one per file in auto_sums; True files are skipped.
        bls_to_use: baselines to average together.
        auto_sums: list of files, paired in order with files_to_flag and with cs.cals.
        cs: object providing time_grid, freqs, cals, and time_indices (e.g. a CalibrationSmoother).

    Returns:
        Float array of shape (len(cs.time_grid), len(cs.freqs)) holding the absolute value of the
        baseline-averaged autos, left at zero for integrations belonging to skipped files.
    '''
    # First pass: baseline-average every file that is not flagged (flagged files map to None).
    per_file_means = {}
    for sum_file, is_flagged in zip(auto_sums, files_to_flag):
        if is_flagged:
            per_file_means[sum_file] = None
        else:
            per_file_means[sum_file] = np.mean([per_file_autos[sum_file][bl] for bl in bls_to_use], axis=0)

    # Second pass: drop each file's average into the rows of the waterfall that its cal file occupies.
    waterfall = np.zeros((len(cs.time_grid), len(cs.freqs)), dtype=float)
    for sum_file, cal_file in zip(auto_sums, cs.cals):
        file_mean = per_file_means[sum_file]
        if file_mean is None:
            continue  # file was flagged, so its rows stay zero
        waterfall[cs.time_indices[cal_file], :] = np.abs(file_mean)
    return waterfall
In [14]:
# Average all candidate autocorrelations together into a single waterfall
avg_candidate_auto = average_autos(good_data, files_to_flag, candidate_autos, auto_sums, cs)
In [15]:
def flag_FM(flags, freqs, freq_range=[87.5e6, 108e6]):
    '''Flag, in place, every channel whose frequency lies within freq_range (in Hz, inclusive).

    Arguments:
        flags: boolean array whose second axis runs over the channels in freqs. Modified in place.
        freqs: array of frequencies in Hz.
        freq_range: [low, high] edges of the band to flag, in Hz. Both edges are included.
    '''
    at_or_above_low = freqs >= freq_range[0]
    at_or_below_high = freqs <= freq_range[1]
    in_band = at_or_above_low & at_or_below_high
    flags[:, in_band] = True
In [16]:
# Flag the configured FM band in the initial flags (modified in place)
flag_FM(initial_cal_flags, cs.freqs, freq_range=FM_freq_range)
In [17]:
def flag_sun(flags, times, max_solar_alt=0):
    '''Flag, in place, every time at which the solar altitude is at or above max_solar_alt.

    Arguments:
        flags: boolean array whose first axis runs over times. Modified in place.
        times: times passed to utils.get_sun_alt to compute the solar altitude.
        max_solar_alt: altitude threshold in degrees; times at or above it are flagged.
    '''
    sun_too_high = utils.get_sun_alt(times) >= max_solar_alt
    flags[sun_too_high, :] = True
In [18]:
# Flag times when the sun is too high in the initial flags (modified in place)
flag_sun(initial_cal_flags, cs.time_grid, max_solar_alt=MAX_SOLAR_ALT)
In [19]:
# Flag all channels of every a priori flagged integration; apriori_flags only exists
# when APRIORI_YAML_PATH was provided, hence the matching guard.
if APRIORI_YAML_PATH is not None:
    initial_cal_flags[apriori_flags, :] = True
In [20]:
def predict_auto_noise(auto, dt, df, nsamples=1):
    '''Predict the noise on an (antenna-averaged) autocorrelation.

    Arguments:
        auto: autocorrelation value(s); only the absolute value is used.
        dt: integration time. The product dt * df must be unitless.
        df: channel width. The product dt * df is truncated to an integer.
        nsamples: number of autocorrelations averaged together to form auto.

    Returns:
        |auto| divided by the square root of half the total sample count.
    '''
    samples_per_auto = int(dt * df)
    total_samples = samples_per_auto * nsamples
    return np.abs(auto) / np.sqrt(total_samples / 2)
In [21]:
# Figure out noise and weights
# median spacing of cs.time_grid scaled by 24 * 3600 (days to seconds; time_grid is treated as JD elsewhere in this notebook)
int_time = 24 * 3600 * np.median(np.diff(cs.time_grid))
# median channel spacing, in the units of cs.freqs (treated as Hz elsewhere in this notebook)
chan_res = np.median(np.diff(cs.freqs))
# predicted noise uses nsamples=1 even though avg_candidate_auto is an average over baselines
noise = predict_auto_noise(avg_candidate_auto, int_time, chan_res, nsamples=1)
# inverse-variance weights, set to zero wherever initial_cal_flags is True
wgts = np.where(initial_cal_flags, 0, noise**-2)
In [22]:
# get slices to index into region of waterfall outside of which it's 100% flagged
# np.flatnonzero always returns a 1D index array, so this also works when exactly one
# integration or channel is unflagged (np.squeeze(np.argwhere(...)) collapses that case to 0D).
unflagged_ints = np.flatnonzero(~np.all(initial_cal_flags, axis=1))
unflagged_chans = np.flatnonzero(~np.all(initial_cal_flags, axis=0))
if unflagged_ints.size == 0 or unflagged_chans.size == 0:
    raise ValueError('initial_cal_flags flags the entire waterfall, so there is no unflagged region to filter.')
# slices run from the first to the last not-fully-flagged integration/channel, inclusive
ints_to_filt = slice(unflagged_ints[0], unflagged_ints[-1] + 1)
chans_to_filt = slice(unflagged_chans[0], unflagged_chans[-1] + 1)
In [23]:
# Filter every autocorrelation individually
cached_output = {}  # handed back to solve_2D_DPSS on every call; presumably lets it reuse work across baselines
models = {}  # baseline -> full-size waterfall with the DPSS model filled into the filtered region
sqrt_mean_sqs = {}  # baseline -> root-mean-square z-score of the unflagged residual
# the filters are built once, for just the region of the waterfall that is not 100% flagged
time_filters, freq_filters = dpss_filters(freqs=cs.freqs[chans_to_filt], # Hz
                                          times=cs.time_grid[ints_to_filt], # JD
                                          freq_scale=FREQ_FILTER_SCALE,
                                          time_scale=TIME_FILTER_SCALE,
                                          eigenval_cutoff=EIGENVAL_CUTOFF)
for bl in candidate_autos:
    # single-baseline "average" puts this auto on the cs time grid
    auto_here = average_autos(good_data, files_to_flag, [bl], auto_sums, cs)
    # start from a copy of the data so that pixels outside the filtered region keep the data values
    models[bl] = np.array(auto_here)
    model, cached_output = solve_2D_DPSS(auto_here[ints_to_filt, chans_to_filt], wgts[ints_to_filt, chans_to_filt], 
                                         time_filters, freq_filters, method='lu_solve', cached_input=cached_output)
    models[bl][ints_to_filt, chans_to_filt] = model
    
    # z-score the residual using noise predicted from the model, then take the RMS over unflagged pixels
    noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)   
    sqrt_mean_sqs[bl] = np.nanmean(np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model)**2)**.5
WARNING:2025-08-31 21:42:04,675:jax._src.xla_bridge:966: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
In [24]:
# Choose which autocorrelations are stable enough to filter on.
# The bound is MAX_AUTO_L2, loosened if necessary to the MIN_FRAC_OF_AUTOS quantile of the
# RMS z-scores so that at least that fraction of the candidates is kept.
quantile_bound = np.quantile(list(sqrt_mean_sqs.values()), MIN_FRAC_OF_AUTOS)
L2_bound = max(quantile_bound, MAX_AUTO_L2)
good_auto_bls = []
for bl in candidate_autos:
    if sqrt_mean_sqs[bl] <= L2_bound:
        good_auto_bls.append(bl)
frac_of_autos_used = len(good_auto_bls) / len(candidate_autos)
print(f'Using {len(good_auto_bls)} out of {len(candidate_autos)} candidate autocorrelations ({frac_of_autos_used:.2%}).')
Using 28 out of 76 candidate autocorrelations (36.84%).
In [25]:
# imshow extent as [left, right, bottom, top]: frequencies divided by 1e6 and times relative to the
# integer part of the first time; the last time is given as the bottom so that time increases downward
extent = [cs.freqs[0]/1e6, cs.freqs[-1]/1e6, cs.time_grid[-1] - int(cs.time_grid[0]), cs.time_grid[0] - int(cs.time_grid[0])]
In [26]:
def plot_all_filtered_bls(N_per_row=8):
    '''Plot the z-score waterfall of the DPSS-filtered residual for every candidate autocorrelation.

    Panels are ordered by increasing root-mean-square z-score, which is shown in each title;
    titles are red for autos above L2_bound. Afterwards, prints the antenna-polarizations that
    are not plotted because they are not candidates.

    Reads notebook-level variables rather than taking them as arguments: candidate_autos,
    sqrt_mean_sqs, models, good_data, files_to_flag, auto_sums, cs, int_time, chan_res,
    initial_cal_flags, extent, L2_bound, and ap_strs.

    Arguments:
        N_per_row: number of panels per row of the figure.
    '''
    N_rows = int(np.ceil(len(candidate_autos) / N_per_row))
    # squeeze=False keeps axes 2D even with a single row or column of panels, so that
    # axes[-1, :] below works no matter how many candidates there are
    fig, axes = plt.subplots(N_rows, N_per_row, figsize=(14, 3 * N_rows), dpi=100, squeeze=False,
                             sharex=True, sharey=True, gridspec_kw={'wspace': 0, 'hspace': .18})
    for i, (ax, bl) in enumerate(zip(axes.flatten(), sorted(sqrt_mean_sqs.keys(), key=lambda bl: sqrt_mean_sqs[bl]))):
        # recompute the z-score of the residual for this baseline, blanking flagged pixels
        auto_here = average_autos(good_data, files_to_flag, [bl], auto_sums, cs)
        noise_model = predict_auto_noise(models[bl], int_time, chan_res, nsamples=1)
        im = ax.imshow(np.where(initial_cal_flags, np.nan, (auto_here - models[bl]) / noise_model).real, 
                       aspect='auto', interpolation='none', cmap='bwr', vmin=-10, vmax=10, extent=extent)
        ax.set_title(f'{bl[0]}{bl[2][0]}: {sqrt_mean_sqs[bl]:.3}', color=('k' if sqrt_mean_sqs[bl] <= L2_bound else 'r'), fontsize=10)
        # a single colorbar, shared by all panels, is attached when drawing the first one
        if i == 0:
            plt.colorbar(im, ax=axes, location='top', label=r'Autocorrelation z-score after DPSS filtering (with $\langle z^2 \rangle^{1/2}$)', extend='both', aspect=40, pad=.015)
        # only the first panel of each row gets a y-axis label
        if i % N_per_row == 0:
            ax.set_ylabel(f'JD - {int(cs.time_grid[0])}')       
    for ax in axes[-1, :]:
        ax.set_xlabel('Frequency (MHz)')
    plt.tight_layout()
    plt.show()
    
    # list every antenna-polarization from the classification table that is not a candidate
    antpols = [(int(ap[:-1]), utils.comply_pol(ap[-1])) for ap in ap_strs]
    other_autos = [f'{ap[0]}{ap[-1][-1]}' for ap in antpols if utils.join_bl(ap, ap) not in candidate_autos]
    print('Not plotted here due to prior antenna flagging:')
    print('\t' + '\n\t'.join(textwrap.wrap(', '.join(other_autos), 80, break_long_words=False)))
Figure 1: Show All DPSS Residual z-Scores¶
This figure shows the z-score waterfall of each antenna. Also shown is the square root of the mean of the square of each waterfall, as a metric of its instability. Antennas whose titles are shown in red are excluded from the average of the most stable antennas, which is used for subsequent flagging.
In [27]:
# Draw Figure 1 with the default number of panels per row
plot_all_filtered_bls()
This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.