LST-Bin¶
by Steven Murray, Tyler Cox and Josh Dillon, last updated 17th June, 2025.
This notebook performs LST-binning, producing a single output file. The input to this notebook consists of two configuration files and one index:
- A fileconf, which is produced by `hera_cal.lstbin_simple.make_lst_bin_config_file()` run over a set of raw files. This file lists all the raw files that correspond to each LST bin, which makes it quick for this notebook to read them in.
- A binning configuration file, config, that specifies all the parameters to use when performing the binning itself.
- The file index that corresponds to the LST bins that will be saved to the output file in this notebook.
The notebook then proceeds to do essentially the same thing as hera_cal.lstbin_simple.lst_bin_files_single_outfile, but with extra plotting and inspection stops along the way.
Imports¶
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin
import sys
import pickle
import itertools
from pathlib import Path
from functools import partial
from datetime import datetime
from time import time as _time
import resource
from collections import UserDict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib as mpl
import attrs
from scipy import linalg, constants, signal
from hera_filters import dspec
from pyuvdata import UVData
from hera_cal.abscal import complex_phase_abscal
from hera_cal import lst_stack as lstbin
from hera_cal.lst_stack.config import LSTConfig
from hera_cal.red_groups import RedundantGroups
from hera_cal.lst_stack import metrics as lstmet
from hera_cal.lst_stack import stats as lststat
from hera_cal.lst_stack.calibration import lstbin_absolute_calibration, _expand_degeneracies_to_ant_gains
from hera_cal.lst_stack import averaging as avg
from hera_cal import vis_clean, redcal, abscal, utils
from hera_qm.time_series_metrics import true_stretches
from hera_notebook_templates.utils import parse_band_str
import importlib
# Wall-clock reference; print_metadata() reports the elapsed run time relative to this.
start_time = _time()
Configuration¶
# Default parameters. When executed through papermill, the injected "# Parameters"
# cell that follows overrides any of these values.
fileconf: str = "/lustre/aoc/projects/hera/Validation/H6C_IDR2/lstbin-outputs/redavg-smoothcal-inpaint-500ns-lstcal/file-config.h5"
fileidx: int = 360  # which output file (group of LST bins) to process; see stackconf.at_single_outfile()
blchunk: int = 0  # 0-based index of the cross-baseline chunk processed by this notebook
papermill_input_path: str = ""
papermill_output_path: str = ""  # if set (and plots are made), a "<path>.hasplots" marker file is written
# The following are defaults that can be overwritten at execution time (preferably by a YAML file)
save_lstbin_data: bool = False
save_metric_data: bool = False
plot_n_worst: int = 5
inpaint_method: str = "per-night"  # one of 'simultaneous', 'per-night', 'pre-inpainted' or 'none'
do_extra_flagging: bool = False
# LST-cal config
do_lstcal: bool = True
lstcal_path: str = None  # directory holding a pre-computed LSTCAL .pkl; required when there is more than one baseline chunk
smoothing_scale: float = 10e6 # smoothing scale in Hz
run_phase_cal: bool = True
run_amplitude_cal: bool = True
run_cross_pol_phase_cal: bool = True
use_inpainted_data: bool = False
# NOTE(review): annotated as `str` but the default is a Path; it is converted with Path() later anyway.
outdir: str = Path("~/lststack-outputs").expanduser()
bl_chunk_size: int = 0  # baselines per chunk; <= 0 means a single chunk containing all baselines
rephase: bool = True
fname_format: str = '{inpaint_mode}/zen.{kind}.{lst:7.5f}.{blchunk}.sum.uvh5'  # must contain '{blchunk}'
overwrite: bool = True
write_med_mad: bool = False
freq_min: float = 0.0  # <= 0 means no lower limit (converted to None)
freq_max: float = 0.0  # <= 0 means no upper limit (converted to None)
history: str = ""
plot_every: int = 1  # plots are made only when fileidx % plot_every == 0
exception_for_zsq_above: float = 1e20  # Raise an exception if any baseline has a zsq greater than this. Useful for flagging weird behaviour for validation.
# In-painting config
inpaint_horizon: float = 1.0
inpaint_standoff: float = 0.0    # ns
inpaint_eigencutoff: float = 1e-12
inpaint_mindelay: float = 500.0  # ns
inpaint_max_gap_factor: float = 1.0
inpaint_max_convolved_flag_frac: float = 0.4
inpaint_sample_cov_fraction: float = 0.0 # Default zero uses variance, one uses full sample covariance
inpaint_use_unbiased_estimator: bool = False # Default False gives a slight over-estimate of the covariance, but is guaranteed to be non-negative
inpaint_spw_buffer_size: int = 2  # Size of buffer on each edge of each spw where flag-gaps don't kill the spw
# Comma-separated "low~high" ranges parsed by parse_band_str (presumably MHz, like fm_*_freq -- TODO confirm)
spws: str = "50.1~62.2,63.3~73.5,74.6~85.4,108.0~116.1,117.3~124.4,125.3~136.2,138.3~148.2,150.1~159.2,159.3~169.9,171.9~181.1,181.4~196.4,198.5~208.4,212.3~220.6,224.4~231.1"
fm_low_freq: float = 87.5 # in MHz
fm_high_freq: float = 108.0 # in MHz
# Flagging Configuration
zscore_threshold: float = 5              # Value of |Z| above which data are flagged.
iterative_flagging_factor: float = 1.5   # When flagging on |Z|^2, the worst offender (W) is flagged, and any other offenders > W/iterative_flagging_factor
watershed_threshold: float = 3           # Value of |Z| above which data surrounding other flagged data will be flagged.
max_flagging_iterations: int = 15             # Maximum number of iterations to perform when flagging.
# Parameters
# (Cell injected by papermill at execution time; these values override the defaults above.)
fileconf = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/file-config.h5"
fileidx = 0
blchunk = 3
max_flagging_iterations = 15
watershed_threshold = 3.0
iterative_flagging_factor = 1.5
zscore_threshold = 5.0
fm_high_freq = 108.0
fm_low_freq = 87.5
spws = "50.1~62.2,63.3~73.5,74.6~85.4,108.0~116.1,117.3~124.4,125.3~136.2,138.3~148.2,150.1~159.2,159.3~169.9,171.9~181.1,181.4~196.4,198.5~208.4,212.3~220.6,224.4~231.1"
inpaint_spw_buffer_size = 2
inpaint_use_unbiased_estimator = False
inpaint_sample_cov_fraction = 0.0
inpaint_max_convolved_flag_frac = 0.4
inpaint_max_gap_factor = 1.0
inpaint_mindelay = 500
inpaint_eigencutoff = 1e-12
inpaint_standoff = 0.0
inpaint_horizon = 1.0
exception_for_zsq_above = 1e20
plot_every = 180
history = ""
freq_max = 0.0
freq_min = 0.0
write_med_mad = False
overwrite = True
fname_format = "{inpaint_mode}/zen.{kind}.{lst:7.5f}.{blchunk}.sum.uvh5"
rephase = True
bl_chunk_size = 2000
outdir = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/"
use_inpainted_data = False
run_cross_pol_phase_cal = True
run_amplitude_cal = True
run_phase_cal = True
smoothing_scale = 10000000.0
lstcal_path = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/redavg-smoothcal-inpaint-500ns-lstcal"
do_lstcal = True
do_extra_flagging = False
inpaint_method = "per-night"
plot_n_worst = 5
save_metric_data = True
save_lstbin_data = True
papermill_output_path = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/lststack.0000.003.ipynb"
papermill_input_path = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/src/hera_notebook_templates/hera_notebook_templates/notebooks/lststack.ipynb"
# Parameter changes for typing: map the "disabled" sentinel values (<= 0) to None.
outdir = Path(outdir)
if freq_max <= 0.0:
    freq_max = None
if freq_min <= 0.0:
    freq_min = None
if bl_chunk_size <= 0:
    bl_chunk_size = None
if "{blchunk}" not in fname_format:
    raise ValueError("The fname_format is missing a {blchunk} format option")

# Only every `plot_every`-th file makes plots. A non-positive value disables
# plotting instead of raising ZeroDivisionError in the modulo.
make_plots = plot_every > 0 and (fileidx % plot_every) == 0
# Metrics are needed either for the plots or to be saved.
get_metrics = bool(make_plots or save_metric_data)

fileconf = Path(fileconf)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not fileconf.is_file():
    raise FileNotFoundError(f"The input file-configuration file is not a file: {fileconf}")
stackconf = LSTConfig.from_file(fileconf)
print("The LST grid was configured with these parameters: \n")
for key, val in attrs.asdict(stackconf.config).items():
    if key != 'data_files':
        print(f"  {key:>36}: {val}")
The LST grid was configured with these parameters: 
                        nlsts_per_file: 2
                                  dlst: 0.0007047089846545072
                                  atol: 1e-10
                             lst_start: 0.0
                               lst_end: 6.283185307179586
                              jd_regex: zen\.(\d+\.\d+)\.
                         calfile_rules: [('/lustre/aoc/projects/hera/H6C/', '/lustre/aoc/projects/hera/h6c-analysis/IDR2/'), ('.uvh5', '.smooth.calfits')]
            where_inpainted_file_rules: None
                           ignore_ants: ()
    antpairs_from_last_file_each_night: True
# Summarise the properties inferred from the raw files.
print("The raw files have the following properties: \n")
for prop_name, prop_value in stackconf.properties.items():
    print(f"  {prop_name:>25}: {prop_value}")
The raw files have the following properties: 
       blts_are_rectangular: True
                   first_jd: 2459861.2529114624
             lst_branch_cut: 5.406879684761707
  time_axis_faster_than_bls: True
              x_orientation: north
# Restrict the configuration to the LST bins of output file number `fileidx`.
stackconf = stackconf.at_single_outfile(fileidx)
print("Raw files used in this notebook (for all bins): \n")
for raw_file in stackconf.matched_files:
    print(raw_file.name)
Raw files used in this notebook (for all bins): zen.2459861.39188.sum.uvh5 zen.2459861.39211.sum.uvh5 zen.2459862.38931.sum.uvh5 zen.2459863.38640.sum.uvh5 zen.2459863.38663.sum.uvh5 zen.2459864.38366.sum.uvh5 zen.2459864.38388.sum.uvh5 zen.2459866.37841.sum.uvh5 zen.2459867.37568.sum.uvh5 zen.2459868.37294.sum.uvh5 zen.2459869.37021.sum.uvh5 zen.2459870.36747.sum.uvh5 zen.2459871.36472.sum.uvh5 zen.2459872.36191.sum.uvh5 zen.2459872.36213.sum.uvh5 zen.2459873.35924.sum.uvh5 zen.2459874.35651.sum.uvh5 zen.2459876.35104.sum.uvh5
_all_antpairs = stackconf.autopairs + stackconf.antpairs
print(f"The data has {len(_all_antpairs)} ant-pairs, and {stackconf.pols} polarizations.")
The data has 16836 ant-pairs, and ['ee', 'en', 'ne', 'nn'] polarizations.
# Make sure the output directory (and any missing parents) exists.
outdir = Path(outdir)
if outdir.exists() and not outdir.is_dir():
    # Fail here with a clear message rather than at the first write into it.
    raise NotADirectoryError(f"Output path exists but is not a directory: {outdir}")
outdir.mkdir(parents=True, exist_ok=True)
print(f"Writing output files to: \n  {outdir}")
Writing output files to: /lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal
# Validate the inpainting mode before any expensive work is done.
_VALID_INPAINT_METHODS = ('simultaneous', 'pre-inpainted', 'per-night', 'none')
if inpaint_method not in _VALID_INPAINT_METHODS:
    raise ValueError("inpaint_method must be one of 'simultaneous', 'pre-inpainted', 'per-night' or 'none'")

# The 'pre-inpainted' mode reads previously inpainted files, which must be configured.
if inpaint_method == 'pre-inpainted' and stackconf.inpaint_files is None:
    raise ValueError("Cannot use inpaint_method='pre-inpainted' without inpainted files")
# Split the cross baselines into chunks that are LST-binned separately, to save RAM.
# bl_chunk_size=None means "one chunk containing every baseline".
_n_antpairs = len(stackconf.antpairs)
bl_chunk_size = _n_antpairs if bl_chunk_size is None else min(bl_chunk_size, _n_antpairs)
n_bl_chunks = int(np.ceil(_n_antpairs / bl_chunk_size))
print(f"This notebook is processing chunk {blchunk+1} of {n_bl_chunks} baseline chunks, each with {bl_chunk_size} baselines.")
This notebook is processing chunk 4 of 9 baseline chunks, each with 2000 baselines.
# With several baseline chunks, LST-cal solutions cannot be solved here and must be
# read from a previous run stored under lstcal_path.
if do_lstcal and n_bl_chunks > 1 and lstcal_path is None:
    raise ValueError(
        "Cannot run LSTCal when only a subset of the baselines are loaded and lstcal_path is not provided"
    )

# Redundant groups (with polarizations), built from `datameta.antpos_enu` keyed by antenna index.
_antpos_by_index = dict(enumerate(stackconf.config.datameta.antpos_enu))
reds_with_pols = RedundantGroups.from_antpos(
    antpos=_antpos_by_index,
    pols=stackconf.pols,
    bl_error_tol=2.0,
)
def print_metadata():
    """Print provenance information for this notebook run.

    Reports the versions of the key packages, the user running the notebook,
    the current time, the elapsed time since the module-level ``start_time``
    and the peak resident memory of this process.
    """
    import getpass

    print("Software Versions Used: ")
    for repo in ['numpy', 'scipy', 'astropy', 'hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata']:
        mdl = importlib.import_module(repo)
        # Not every package is guaranteed to expose __version__.
        print(f'{repo:>25}: {getattr(mdl, "__version__", "unknown")}')

    # os.system("whoami") writes to the kernel process's stdout, which Jupyter does
    # not capture into the notebook, so query the user name from Python instead.
    try:
        user = getpass.getuser()
    except (KeyError, OSError):
        user = "unknown"
    print(f"Run by: {user}")
    print(f"Run on {datetime.now()}")
    print(f"Execution of notebook took: {(_time() - start_time)/60.0:.2f} minutes")
    # ru_maxrss is reported in kibibytes on Linux but in bytes on macOS.
    maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    bytes_per_unit = 1 if sys.platform == 'darwin' else 1024
    print(f"Peak memory in this notebook run: {maxrss * bytes_per_unit / 1024**3:.2f} GB")
# Every LST bin of this output file must be covered by some matched data; the
# values in time_indices are presumably LST-bin indices (TODO confirm).
_lst_bins_with_data = set(itertools.chain.from_iterable(
    idx_array.tolist() for idx_array in stackconf.time_indices
))
if len(_lst_bins_with_data) != stackconf.n_lsts:
    print("LST-Stacking for files where not all of the LST-bins have associated data is not yet supported.")
    print_metadata()
    sys.exit(0)
# When this run makes plots AND was launched by papermill, drop an empty
# "<output>.hasplots" marker so the execution script knows to publish a copy of
# the executed notebook to the public-facing notebook directory.
if make_plots and papermill_output_path:
    pth = Path(f"{papermill_output_path}.hasplots")
    pth.touch()
data_is_redundantly_averaged = stackconf.config.is_redundantly_averaged
_redavg_verb = "is" if data_is_redundantly_averaged else "is not"
print(f"This data {_redavg_verb} redundantly averaged.")
This data is not redundantly averaged.
# Keyword arguments shared by every output filename in this notebook; callers
# supply `kind` (and may override any of these, e.g. fname_format or blchunk).
_fname_defaults = dict(
    lst=stackconf.lst_grid_edges[0],
    pols=stackconf.pols,
    inpaint_mode=True,
    lst_branch_cut=stackconf.properties["lst_branch_cut"],
    fname_format=fname_format.replace("{blchunk}", "{blchunk:03}"),
    blchunk=blchunk,
)
get_fname = partial(lstbin.io.format_outfile_name, **_fname_defaults)
# Spectral windows used for per-night inpainting (and flagging). Without an explicit
# band string, a single window spanning every channel is used.
if spws is not None:
    f = stackconf.config.datameta.freq_array
    _, _, _, _, spws, _ = parse_band_str(spws, f)
else:
    spws = (slice(0, None, None),)
Define Stacking/Averaging Functions¶
Define and initialize the output files that we will write in this notebook:
Now, define a function that uses the configuration we've established and performs LST-binning for a subset of baselines.
def stack_blchunk(bl_chunk: int | str, inpainted_mode: str):
    """Load, stack and LST-average one chunk of baselines.

    Parameters
    ----------
    bl_chunk
        Chunk to load, forwarded as ``bl_chunk_to_load``: an integer chunk index
        for cross-correlations, or ``'autos'`` for the auto-correlations.
    inpainted_mode
        The inpainting method (one of the values accepted for ``inpaint_method``).

    Returns
    -------
    stacks
        The stacks returned by ``lst_bin_files_from_config`` (one per LST bin).
    rdcs
        The ``reduce_lst_bins`` result for each stack, in the same order.
    """
    stacks: list[UVData] = lstbin.binning.lst_bin_files_from_config(
        stackconf,
        bl_chunk_to_load=bl_chunk,
        nbl_chunks=n_bl_chunks,
    )
    # Only pre-inpainted input is averaged in inpainted mode.
    use_inpainted_mode = inpainted_mode == 'pre-inpainted'
    # MED/MAD are obtained later when this notebook does inpainting, so only
    # compute them here when no inpainting is performed.
    get_mad = write_med_mad and inpainted_mode == 'none'
    rdcs = [
        lstbin.averaging.reduce_lst_bins(
            lststack=stack,
            inpainted_mode=use_inpainted_mode,
            get_mad=get_mad,
        )
        for stack in stacks
    ]
    return stacks, rdcs
Plotting Style Setup¶
# One plotting style per integer JD: cycle through the 10 default colours and
# switch line style every 10 JDs. The line-style list is cycled too, so more than
# 40 JDs no longer raises an IndexError (styles then repeat).
_LINESTYLES = ['-', '--', ':', '-.']
data_jd_ints = sorted({int(meta.times[0]) for meta in stackconf.matched_metas})
styles = {}
for i, jdint in enumerate(data_jd_ints):
    styles[jdint] = {'color': f"C{i%10}", 'ls': _LINESTYLES[(i // 10) % len(_LINESTYLES)]}
Define Subsets of Data to Consider¶
Bands¶
# Channel ranges used to summarise the data: eight sub-bands of up to 200 channels
# covering channels 0-1536, followed by the low, high and full bands.
_SUBBAND_WIDTH = 200
_NCHAN_FULL = 1536
bands_considered = [
    (start, min(start + _SUBBAND_WIDTH, _NCHAN_FULL))
    for start in range(0, _NCHAN_FULL, _SUBBAND_WIDTH)
] + [
    (0, 450),            # low band
    (450, _NCHAN_FULL),  # high band
    (0, _NCHAN_FULL),    # full band
]
Baselines¶
def get_all_antenna_sectors(antpos=None, center_ants=(165, 166, 145), outrigger_radius=200.0):
    """Assign every antenna to an array sector.

    Positions are measured relative to the mean position of ``center_ants``,
    using only the first two coordinates. Antennas further than
    ``outrigger_radius`` from that centre are sector 4 (outriggers); the rest are
    split into three 120-degree wedges by azimuth ``theta``:

    * sector 1: -60 deg <= theta < 60 deg
    * sector 2:  60 deg <= theta <= 180 deg
    * sector 3: -180 deg <= theta < -60 deg

    Parameters
    ----------
    antpos
        Antenna positions indexed by antenna number. Defaults to
        ``stackconf.config.datameta.antenna_positions``.
    center_ants
        Antenna numbers whose mean position defines the array centre.
    outrigger_radius
        Distance (same units as ``antpos``) beyond which an antenna is an outrigger.

    Returns
    -------
    dict
        Mapping of antenna index to sector number (1-4).
    """
    if antpos is None:
        antpos = stackconf.config.datameta.antenna_positions
    zero_pos = np.mean([antpos[a] for a in center_ants], axis=0)

    sectors = {}
    for ant, pos in enumerate(antpos):
        rec = pos - zero_pos
        theta = np.arctan2(rec[1], rec[0])
        bllen = np.sqrt(rec[0]**2 + rec[1]**2)
        if bllen > outrigger_radius:
            sectors[ant] = 4  # outrigger
        elif -np.pi / 3 <= theta < np.pi / 3:
            sectors[ant] = 1
        elif np.pi / 3 <= theta <= np.pi:
            # arctan2 returns exactly +pi on the negative x-axis, so the upper
            # bound must be inclusive or such antennas get no sector at all.
            sectors[ant] = 2
        elif -np.pi <= theta < -np.pi / 3:
            sectors[ant] = 3
    return sectors
# Sector number (1-4) of every antenna; used by the inter-/intra-sector baseline subsets.
sectors = get_all_antenna_sectors()
def getblvec(a, b, antpos=None):
    """Return the baseline vector ``antpos[a] - antpos[b]``.

    Parameters
    ----------
    a, b
        Antenna indices.
    antpos
        Optional array of antenna positions indexed by antenna number. If omitted,
        positions are read from ``auto_stacks[0]``, supporting both the pyuvdata
        3.2+ layout (``.telescope.antenna_positions``) and the older layout
        (``.antenna_positions``).
    """
    if antpos is None:
        first_stack = auto_stacks[0]
        # Older objects may have no `.telescope` attribute at all, so look it up
        # with getattr instead of evaluating `first_stack.telescope` directly.
        telescope = getattr(first_stack, 'telescope', None)
        if telescope is not None and hasattr(telescope, 'antenna_positions'):
            # for pyuvdata 3.2+
            antpos = telescope.antenna_positions
        else:
            # for older pyuvdata
            antpos = first_stack.antenna_positions
    return antpos[a] - antpos[b]
def getbllen(a, b):
    """Return the Euclidean length of the baseline between antennas ``a`` and ``b``."""
    blvec = getblvec(a, b)
    return np.sqrt(np.sum(blvec * blvec))
def _is_same_pol(bl):
    """True when the polarization string of ``bl`` repeats one feed (e.g. 'ee', 'nn')."""
    return len(set(bl[2])) == 1


def all_ee(bl):
    return bl[2] == 'ee'


def all_nn(bl):
    return bl[2] == 'nn'


def short_bls(bl):
    return getbllen(bl[0], bl[1]) <= 60.0 and _is_same_pol(bl)


def long_bls(bl):
    return getbllen(bl[0], bl[1]) > 60.0 and _is_same_pol(bl)


def intersector_bls(bl):
    return sectors[bl[0]] != sectors[bl[1]] and _is_same_pol(bl)


def intrasector_bls(bl):
    return sectors[bl[0]] == sectors[bl[1]] and _is_same_pol(bl)


def _any_baseline(bl):
    return True


# Named predicates over (ant1, ant2, pol) baseline keys, used to select data subsets.
subsets = {
    'all': _any_baseline,
    'ee-only': all_ee,
    'nn-only': all_nn,
    'Short (<60 m) baselines': short_bls,
    'Long (>60 m) baselines': long_bls,
    'Inter-sector baselines': intersector_bls,
    "Intra-sector baselines": intrasector_bls,
}
# Inpaint separately below and above the FM band (band edges given in MHz).
inpaint_bands = [(0, fm_low_freq), (fm_high_freq, np.inf)]  # default below and above FM
# Convert each band to a channel slice. Assumes freq_array is sorted, so the
# channels inside a band are contiguous -- TODO confirm.
_inp = []
for _bnd in inpaint_bands:
    idx = np.nonzero((stackconf.config.datameta.freq_array >= _bnd[0] * 1e6) & (stackconf.config.datameta.freq_array < _bnd[1]*1e6))[0]
    if idx.size == 0:
        # The data do not reach into this band; idx[0] would raise an IndexError.
        print(f"No channels fall in inpaint band {_bnd} MHz; skipping it.")
        continue
    _inp.append(slice(idx[0], idx[-1] + 1))
inpaint_bands = _inp
print("Using the following bands for inpainting (channels):")
for bnd in inpaint_bands:
    print(bnd)
Using the following bands for inpainting (channels): slice(np.int64(0), np.int64(333), None) slice(np.int64(501), np.int64(1536), None)
Perform Initial Stacking of Autos and Crosses¶
# Stack and LST-average the auto-correlations, then this notebook's chunk of crosses.
auto_stacks, autos_lstavg = stack_blchunk(bl_chunk='autos', inpainted_mode=inpaint_method)
cross_stacks, cross_lstavg = stack_blchunk(bl_chunk=blchunk, inpainted_mode=inpaint_method)
LST-Bin Calibration¶
%%time
def _recompute_lst_averages(stacks):
    """Return the ``reduce_lst_bins`` LST average of each (now calibrated) stack.

    ``stack_blchunk`` passes ``reduce_lst_bins`` a boolean ``inpainted_mode`` that
    is True only for ``inpaint_method == 'pre-inpainted'``; the same convention is
    used here. The raw ``inpaint_method`` string would be truthy for every method,
    including 'none'.
    """
    return [
        lstbin.averaging.reduce_lst_bins(
            lststack=stack,
            inpainted_mode=inpaint_method == 'pre-inpainted',
            get_mad=write_med_mad,
        )
        for stack in stacks
    ]


if do_lstcal and n_bl_chunks == 1:
    # Solve for LST-cal here. This needs all baselines at once, which only fit in
    # memory as a single chunk (e.g. redundantly-averaged data).
    all_calibration_parameters = {}
    for i, (stack, lstavg_model, auto_model) in enumerate(zip(cross_stacks, cross_lstavg, autos_lstavg)):
        calibration_parameters, gains = lstbin_absolute_calibration(
            stack=stack,
            model=lstavg_model['data'],
            all_reds=reds_with_pols,
            inpaint_bands=inpaint_bands,
            auto_stack=auto_stacks[i],
            auto_model=auto_model['data'],
            calibrate_inplace=True,  # calibrate the stack in place
            run_amplitude_cal=run_amplitude_cal,  # run amplitude calibration
            run_phase_cal=run_phase_cal,  # run phase calibration
            run_cross_pol_phase_cal=run_cross_pol_phase_cal,
            use_inpainted_data=use_inpainted_data,
            smoothing_scale=smoothing_scale
        )

        # Store metadata alongside the solutions in the output pickle.
        calibration_parameters["freqs"] = stack.freq_array
        calibration_parameters["flags"] = stack.flags[:, 0, :]
        calibration_parameters["times"] = stack.times
        calibration_parameters["antpairs"] = stack.antpairs
        calibration_parameters["lst"] = stackconf.lst_grid[i]

        all_calibration_parameters[stackconf.lst_grid[i]] = calibration_parameters

    # Write all calibration parameters, keyed by LST-grid value.
    cal_fname = get_fname(
        fname_format=fname_format.replace("{blchunk}", "{blchunk:03}").replace(".sum.uvh5", ".pkl"),
        kind='LSTCAL',
    )
    outfile = outdir / cal_fname
    with open(outfile, 'wb') as handle:
        pickle.dump(all_calibration_parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Recompute auto and cross averages after calibration - needed for STD files
    cross_lstavg = _recompute_lst_averages(cross_stacks)
    autos_lstavg = _recompute_lst_averages(auto_stacks)

elif do_lstcal and n_bl_chunks > 1:
    # Apply LST-cal solutions that were computed previously and stored under lstcal_path.
    lstcal_path = Path(lstcal_path)

    # Get the calibration filename (solutions are always stored with blchunk=0)
    cal_fname = get_fname(
        fname_format=fname_format.replace("{blchunk}", "{blchunk:03}").replace(".sum.uvh5", ".pkl"),
        kind='LSTCAL', blchunk=0
    )

    # NOTE: pickle.load can execute arbitrary code; lstcal_path must be a trusted location.
    with open(lstcal_path / cal_fname, 'rb') as calibration_file:
        all_calibration_parameters = pickle.load(calibration_file)

    # The integer-grid antenna positions depend only on reds_with_pols, so compute
    # them once rather than once per LST bin.
    transformed_antpos = redcal.reds_to_antpos(reds_with_pols)
    abscal._put_transformed_array_on_integer_grid(transformed_antpos)

    for i, (stack, auto_stack) in enumerate(zip(cross_stacks, auto_stacks)):
        lst_value = stackconf.lst_grid[i]
        # Check that current LST-grid point matches the grid found in the calibration file
        if lst_value not in all_calibration_parameters:
            raise KeyError(f"LST-grid value {lst_value} not found in calibration parameter file")
        calibration_parameters = all_calibration_parameters[lst_value]

        # Unique antennas appearing in this chunk's baselines
        gain_ants = list(set(itertools.chain.from_iterable(stack.antpairs)))

        # Unique per-antenna polarizations obtained by splitting each visibility pol
        unique_pols = list(
            set(sum(map(list, [utils.split_pol(pol) for pol in stack.pols]), []))
        )

        # Turn the tip-tilt (and optional relative-phase) solutions into per-antenna phase gains
        phase_gains = {}
        for (ant, pol) in itertools.product(gain_ants, unique_pols):
            tip_tilt = calibration_parameters[f"T_{pol}"]
            tip_tilt = np.where(np.isfinite(tip_tilt), tip_tilt, 0)  # zero out non-finite solutions
            phase_gains[(ant, pol)] = np.exp(1j * np.dot(tip_tilt, transformed_antpos[ant])[:, 0, :])
            # If relative phase calibration was done, apply gains
            if "delta" in calibration_parameters and pol == "Jee":
                phase_gains[(ant, pol)] *= np.exp(1j * calibration_parameters["delta"])

        # Compute gains and smooth
        gains = _expand_degeneracies_to_ant_gains(
            stack,
            amplitude_parameters={
                f"A_{pol}": calibration_parameters[f"A_{pol}"] for pol in unique_pols
            },
            phase_gains=phase_gains,
            inpaint_bands=inpaint_bands,
            auto_stack=auto_stack
        )

        # Apply calibration solutions in place to the crosses and the autos
        for polidx, pol in enumerate(stack.pols):
            for apidx, (ant1, ant2) in enumerate(stack.antpairs):
                antpol1, antpol2 = utils.split_bl((ant1, ant2, pol))
                bl_gain = gains[antpol1] * gains[antpol2].conj()
                stack.data[:, apidx, :, polidx] /= bl_gain
            for apidx, (ant1, ant2) in enumerate(auto_stack.antpairs):
                antpol1, antpol2 = utils.split_bl((ant1, ant2, pol))
                auto_gain = gains[antpol1] * gains[antpol2].conj()
                auto_stack.data[:, apidx, :, polidx] /= auto_gain

    # Recompute auto and cross averages after calibration - needed for STD files
    cross_lstavg = _recompute_lst_averages(cross_stacks)
    autos_lstavg = _recompute_lst_averages(auto_stacks)
CPU times: user 1min 11s, sys: 24.4 s, total: 1min 35s Wall time: 1min 35s
Autos¶
Compute Stats for Autos¶
# Per-LST-bin statistics of the auto-correlations (redundant groups are only
# supplied when the data are redundantly averaged).
auto_stats = []
for _auto_rdc in autos_lstavg:
    auto_stats.append(
        lstmet.LSTBinStats.from_reduced_data(
            rdc=_auto_rdc,
            antpairs=stackconf.autopairs,
            pols=stackconf.pols,
            reds=reds_with_pols if data_is_redundantly_averaged else None,
        )
    )
Inpaint Autos¶
_INPAINT_CACHE_ = {}


def _auto_inpaint_kwargs():
    """Build fresh keyword arguments for inpainting the autos (new dict/list objects per call)."""
    return dict(
        inpaint_bands=inpaint_bands,
        return_models=make_plots,
        cache=_INPAINT_CACHE_,
        filter_properties={
            "min_dly": inpaint_mindelay,
            "horizon": inpaint_horizon,
            "standoff": inpaint_standoff,
        },
        eigenval_cutoff=[inpaint_eigencutoff],
        max_gap_factor=inpaint_max_gap_factor,
        max_convolved_flag_frac=inpaint_max_convolved_flag_frac,
        use_unbiased_estimator=inpaint_use_unbiased_estimator,
        sample_cov_fraction=inpaint_sample_cov_fraction,
        use_night_to_night_cov=inpaint_method == 'simultaneous',
        spws=spws,
        post_inpaint_flag_buffer_size=inpaint_spw_buffer_size,
    )


if inpaint_method in ['simultaneous', 'per-night']:
    auto_inpaint_dpss_models = []
    for i, stack in enumerate(auto_stacks):
        # The autos serve as their own noise reference (second argument).
        _avg, dpss_models, inp_flags = avg.average_and_inpaint_simultaneously(
            stack, stack, **_auto_inpaint_kwargs()
        )
        auto_inpaint_dpss_models.append(dpss_models)

        # Replace the LST average with the inpainted one; nsamples only counts
        # samples that are unflagged both before and after inpainting.
        _target = autos_lstavg[i]
        _target['data'] = _avg['data']
        _target['flags'] = _avg['flags']
        _target['nsamples'] = np.sum(np.where(inp_flags.flags | stack.flags, 0, stack.nsamples), axis=0)
Plot¶
def make_auto_plot(auto_stacks: list[UVData], lstbin: list[dict]):
    """Plot each night's auto-correlation spectrum together with the LST-binned mean.

    Panels are laid out with one row per (auto-pair, polarization) and one column
    per LST bin. Flagged channels are replaced by NaN so they are not drawn.

    Parameters
    ----------
    auto_stacks
        Auto-correlation stacks, one per LST bin.
    lstbin
        The matching LST-averaged results (dicts with 'data' and 'flags').
        Note: inside this function the name shadows the ``lstbin`` module.
    """
    npols = len(stackconf.pols)
    fig, ax = plt.subplots(
        len(stackconf.autopairs) * npols, len(auto_stacks),
        sharex=True, squeeze=False, constrained_layout=True,
        figsize=(12, 6)
    )
    for i, (stack, avg) in enumerate(zip(auto_stacks, lstbin)):
        freqs_mhz = stack.freq_array / 1e6
        for j, autopair in enumerate(stackconf.autopairs):
            for p, pol in enumerate(stackconf.pols):
                axx = ax[j * npols + p, i]

                # Fetch this baseline's flags/data once (indexed by night along the
                # first axis) instead of once per night.
                bl_key = autopair + (pol,)
                bl_flags = stack.get_flags(bl_key)
                bl_data = stack.get_data(bl_key)
                for k, t in enumerate(stack.times):
                    axx.plot(
                        freqs_mhz,
                        np.where(bl_flags[k], np.nan, np.abs(bl_data[k])),
                        label=f"{int(t)}" if not p else None,
                        **styles[int(t)]
                    )
                axx.set_yscale('log')
                axx.set_title(f"Pair {autopair}, pol={pol}, LST {stackconf.lst_grid[i]*12/np.pi:.3f} hr")

                # plot the mean
                axx.plot(
                    freqs_mhz,
                    np.where(avg['flags'][j, :, p], np.nan, np.abs(avg['data'][j, :, p])),
                    label='LSTBIN',
                    color='k', lw=2
                )
    ax[0, 0].legend(ncols=3)
    for axx in ax[-1]:
        axx.set_xlabel("Frequency [MHz]")
        
def make_auto_plot_multi_autos(auto_stacks: list[UVData], lstbin: list[dict]):
    """Summarise many auto-correlations per night, polarization and LST bin.

    For each night, the 1st-99th percentile range across antennas is shaded, the
    mean across antennas is drawn, and every antenna that leaves that range at any
    frequency is drawn individually. The black line is the mean, over antennas, of
    the LST-binned autos.

    Parameters
    ----------
    auto_stacks
        Auto-correlation stacks, one per LST bin.
    lstbin
        The matching LST-averaged results (dicts with 'data' and 'flags').
        Note: inside this function the name shadows the ``lstbin`` module.
    """
    fig, ax = plt.subplots(
        len(stackconf.pols), len(auto_stacks),
        sharex=True, squeeze=False, constrained_layout=True,
        figsize=(12, 2 + len(stackconf.pols)*1.5)
    )
    for i, (stack, avg) in enumerate(zip(auto_stacks, lstbin)):
        freqs_mhz = stack.freq_array / 1e6
        for p, pol in enumerate(stackconf.pols):
            axx = ax[p, i]

            for k, t in enumerate(stack.times):
                # Stack arrays are (night, baseline, freq, pol): select night k and
                # all autos. (Indexing [:, k] would pick the k-th antenna instead.)
                flg = stack.flags[k, :, :, p]
                d = np.where(flg, np.nan, np.abs(stack.data[k, :, :, p]))
                percentiles = np.nanpercentile(d, [1, 99], axis=0)
                mean = np.nanmean(d, axis=0)

                axx.fill_between(
                    freqs_mhz,
                    percentiles[0],
                    percentiles[1],
                    alpha=0.4,
                    **styles[int(t)]
                )
                axx.plot(
                    freqs_mhz,
                    mean,
                    label=f"{int(t)}" if not p else None,
                    **styles[int(t)]
                )

                # Antennas outside the percentile band at any frequency. Flagged (NaN)
                # samples are mapped to +/-inf so they can never count as outliers,
                # whatever the amplitude scale of the data.
                dd = np.where(np.isnan(d), np.inf, d)
                p0 = np.where(np.isnan(percentiles[0]), 0, percentiles[0])
                outliers = np.where(np.any(dd < p0, axis=1))[0].tolist()

                dd = np.where(np.isnan(d), -np.inf, d)
                p1 = np.where(np.isnan(percentiles[1]), np.inf, percentiles[1])
                outliers += np.where(np.any(dd > p1, axis=1))[0].tolist()
                outliers = np.unique(outliers)

                for idx in outliers:
                    # Draw on this panel (plt.plot would target the current axes).
                    axx.plot(
                        freqs_mhz,
                        d[idx],
                        lw=1,
                        **styles[int(t)]
                    )

            axx.set_yscale('log')
            axx.set_title(f"pol={pol}, LST {stackconf.lst_grid[i]*12/np.pi:.3f} hr")
            # plot the mean (over all autos)
            axx.plot(
                freqs_mhz,
                np.nanmean(np.where(avg['flags'][:, :, p], np.nan, np.abs(avg['data'][:, :, p])), axis=0),
                label='mean',
                lw=2,
                color='k'
            )

    ax[0, 0].legend(ncols=3)
    for axx in ax[-1]:
        axx.set_xlabel("Frequency [MHz]")
if make_plots:
    # Summary plot when there are many autos; one panel per auto-pair otherwise.
    _plot_autos = make_auto_plot_multi_autos if len(stackconf.autopairs) > 1 else make_auto_plot
    _plot_autos(auto_stacks, autos_lstavg)
Cross-Pairs¶
Improve Flags¶
from scipy.signal import convolve
def watershed_ndim(metrics: np.ndarray, flags: np.ndarray, axis: int, threshold: float, size: int = 1):
    """Grow flags along one axis of a multi-dimensional array ("watershed" flagging).

    Repeatedly flags every sample that lies within ``size`` samples (along
    ``axis``) of an already-flagged sample and has ``metrics >= threshold``,
    until an iteration adds no new flags.

    Parameters
    ----------
    metrics
        Metric array, same shape as ``flags``.
    flags
        Boolean input flags. Not modified.
    axis
        Axis along which flags may spread.
    threshold
        Minimum metric value for a neighbouring sample to become flagged.
    size
        Half-width, in samples, of the neighbourhood along ``axis``.

    Returns
    -------
    np.ndarray
        Boolean array of the grown flags.

    Raises
    ------
    ValueError
        If ``metrics`` and ``flags`` have different shapes.
    """
    if metrics.shape != flags.shape:
        raise ValueError(f"metrics shape {metrics.shape} does not match flags shape {flags.shape}")
    outflags = np.array(flags, dtype=bool, copy=True)

    # Box kernel of width 2*size+1 along `axis` (length 1 along every other axis).
    # It must be ones: an all-zeros kernel means no sample ever has a flagged
    # neighbour, so the watershed would never flag anything.
    shape = np.ones(metrics.ndim, dtype=int)
    shape[axis] = 2 * size + 1
    kernel = np.ones(shape)

    while True:
        nflags = np.count_nonzero(outflags)
        # Direct (non-FFT) convolution of 0/1 values is exact, so "> 0" is safe.
        is_neighbor_flagged = convolve(outflags.astype(float), kernel, mode='same', method='direct') > 0
        outflags |= is_neighbor_flagged & (metrics >= threshold)
        if np.count_nonzero(outflags) == nflags:
            break

    return outflags
def iterative_flagger(func, stack, variance, max_iter=10):
    """Apply a flagging function repeatedly until the flags stop changing.

    On each iteration the squared z-scores are recomputed from the current flags
    and passed to ``func(stack, zsq)``, which is expected to update
    ``stack.flags`` in place. Iteration stops when a pass adds no new flags or
    after ``max_iter`` passes.

    Returns
    -------
    The squared z-scores from the last computation (made before the final call
    to ``func``). If ``max_iter <= 0``, no flagging is done and the z-scores of
    the input flags are returned.
    """
    if max_iter <= 0:
        # Nothing to iterate; still return well-defined z-scores.
        return lstmet.get_squared_zscores_flagged(stack, variance=variance)

    nflags = np.sum(stack.flags)
    for niter in range(max_iter):
        print(f"    iter {niter}: nflags = {nflags} ({nflags*100/stack.flag_array.size:.2f} %)")
        zsq = lstmet.get_squared_zscores_flagged(stack, variance=variance)
        func(stack, zsq)
        new_nflags = np.sum(stack.flags)
        if new_nflags == nflags:
            break
        nflags = new_nflags

    return zsq
def do_flagging(funcs, stacks, auto_stats, max_iter=10):
    """Run every flagger in ``funcs`` (in order, each iterated) on every stack.

    The variance fed to the flaggers is half of
    ``get_nightly_predicted_variance_stack`` for the stack's auto statistics.
    Returns, per stack, the squared z-scores returned by the last flagger.
    """
    out_zsq = []
    for stack_idx, (stack, stat) in enumerate(zip(stacks, auto_stats)):
        print(f"Stack {stack_idx}")
        night_variance = lstmet.get_nightly_predicted_variance_stack(stack, stat, flag_if_inpainted=True) / 2

        for flagger in funcs:
            print(f"  Flagger: {flagger.__name__}")
            zsq = iterative_flagger(flagger, stack, night_variance, max_iter=max_iter)

        out_zsq.append(zsq)
    return out_zsq
    
def direct_zscore_pruning(stack, zsq):
    """Flag samples whose Z^2 exceeds ``zscore_threshold**2`` AND is within a factor
    ``iterative_flagging_factor`` of the largest Z^2 along axis 0 (presumably the
    night axis). Updates ``stack.flags`` in place."""
    metric = zsq.metrics
    above_threshold = metric > zscore_threshold**2
    near_worst = metric >= np.nanmax(metric, axis=0) / iterative_flagging_factor
    stack.flags |= (above_threshold & near_worst)
    
def watershed(stack, zsq):
    """Grow ``stack.flags`` along axis -2 into neighbours with Z^2 >= ``watershed_threshold**2``."""
    grown = watershed_ndim(
        zsq.metrics,
        stack.flags,
        axis=-2,
        threshold=watershed_threshold**2,
        size=1,
    )
    stack.flags |= grown
zsquare = [lstmet.get_squared_zscores_flagged(stack, auto_stats=stats) for stack, stats in zip(cross_stacks, auto_stats)]
if do_extra_flagging:
    # Keep a copy of the original flags and Z^2 so we can check for differences later
    original_flags = [stack.flag_array.copy() for stack in cross_stacks]
    original_zsquare = zsquare
    zsquare = do_flagging([direct_zscore_pruning, watershed], cross_stacks, auto_stats, max_iter=max_flagging_iterations)
# Refuse to continue if any Z^2 exceeds the configured hard limit.
if any(np.any(zsq.metrics > exception_for_zsq_above) for zsq in zsquare):
    raise RuntimeError(f"The maximum zsquares before inpainting are {[np.nanmax(zsq.metrics) for zsq in zsquare]}, which is over the threshold for raising an error")
Inpaint Crosses¶
We simultaneously inpaint and average the data with the flags we're given, using the covariance of a DPSS fit to the mean as a constraint.
if make_plots:
    # We need this if we want to make a plot comparing the new simultaneous inpaint
    # Snapshot of each LST bin's averaged data before the inpainting cell
    # below replaces cross_lstavg[i]['data'].
    original_data_mean = [lstavg['data'].copy() for lstavg in cross_lstavg]
%%time
if inpaint_method in ['simultaneous', 'per-night']:
    # One entry per LST bin; presumably only meaningful when
    # return_models (= make_plots) is True -- TODO confirm.
    inpaint_dpss_models = []
    for i, (stack, auto_stack) in enumerate(zip(cross_stacks, auto_stacks)):
        # Average over nights while inpainting within `inpaint_bands`. The
        # night-to-night covariance is used only for the 'simultaneous' method
        # (see use_night_to_night_cov below).
        _avg, dpss_models, inp_flags = avg.average_and_inpaint_simultaneously(
            stack,
            auto_stack,
            inpaint_bands = inpaint_bands,
            return_models = make_plots,
            cache = _INPAINT_CACHE_,
            filter_properties = {
                "min_dly": inpaint_mindelay, 
                "horizon": inpaint_horizon,
                "standoff": inpaint_standoff, 
            },
            eigenval_cutoff=[inpaint_eigencutoff], 
            max_gap_factor=inpaint_max_gap_factor,
            max_convolved_flag_frac=inpaint_max_convolved_flag_frac,
            use_unbiased_estimator=inpaint_use_unbiased_estimator,
            sample_cov_fraction=inpaint_sample_cov_fraction,
            use_night_to_night_cov=inpaint_method=='simultaneous',
            spws=spws,
            post_inpaint_flag_buffer_size=inpaint_spw_buffer_size,
        )
        inpaint_dpss_models.append(dpss_models)
        cross_lstavg[i]['data'] = _avg['data']
        cross_lstavg[i]['flags'] = _avg['flags']
        # Sum nsamples over nights (axis 0), counting only samples flagged
        # neither in the stack nor by the inpainter's returned flags.
        cross_lstavg[i]['nsamples'] = np.sum(np.where(inp_flags.flags | stack.flags, 0, stack.nsamples),axis=0)
CPU times: user 2min 51s, sys: 15 s, total: 3min 6s Wall time: 3min 8s
We need to satisfy the condition that all non-finite data is flagged, so we check that here.
# Every non-finite averaged value must be flagged; fail loudly otherwise.
if any(
    np.any(np.logical_not(np.logical_or(np.isfinite(rdc['data']), rdc['flags'])))
    for rdc in cross_lstavg
):
    raise RuntimeError("Something is wrong with the LST-binner: got NaN values for the averaged data that aren't flagged!")
Let's make some plots to inspect how the inpainting did.
def get_biggest_inpaint_differences(band, n: int = 5):
    """Find the baseline-pols whose LST-average changed most through inpainting.

    Compares the inpainted average ``cross_lstavg[0]['data']`` with the
    pre-inpainting mean ``original_data_mean[0]``. Only channels that are in
    ``band`` and in one of ``spws`` are used. Channels with zero total
    nsamples are excluded.

    Parameters
    ----------
    band
        Channel selector (slice or index array) for the band of interest.
    n
        Number of baseline-pols to return.

    Returns
    -------
    list of tuple
        ``(ant1, ant2, pol, chan)`` for the ``n`` baseline-pols with the largest
        absolute difference, largest first. ``chan`` is the ABSOLUTE channel
        index (into the full frequency axis) at which that maximum occurs.
    """
    nfreqs = cross_stacks[0].Nfreqs
    pols = cross_stacks[0].pols
    antpairs = cross_stacks[0].antpairs

    mask = np.zeros(nfreqs, dtype=bool)
    mask[band] = True
    spwmask = np.zeros(nfreqs, dtype=bool)
    for spw in spws:
        spwmask[spw] = True
    spwmask &= mask
    # Absolute channel index of each column in the masked arrays below.
    chans = np.nonzero(spwmask)[0]
    if chans.size == 0:
        return []

    # Arrays are indexed (bls, freqs, pols).
    inpdata = cross_lstavg[0]['data'][:, spwmask]
    flgdata = np.where(
        cross_lstavg[0]['nsamples'][:, spwmask] == 0, np.nan, original_data_mean[0][:, spwmask]
    )
    absdiff = np.abs(inpdata - flgdata)
    # Replace NaN (no-data) differences by -1 so that neither max nor argmax
    # can select them (np.argmax would otherwise return the first NaN).
    absdiff = np.where(np.isnan(absdiff), -1.0, absdiff)
    diff = absdiff.max(axis=1).flatten()
    freqidx = chans[absdiff.argmax(axis=1).flatten()]

    baselines_with_biggest_differences = []
    for ii in np.argsort(diff)[::-1][:n]:
        # diff was flattened from (bls, pols) in C order.
        ap_idx, pol_idx = divmod(ii, len(pols))
        baselines_with_biggest_differences.append((*antpairs[ap_idx], pols[pol_idx], freqidx[ii]))
    return baselines_with_biggest_differences
if make_plots and inpaint_method in ['simultaneous', 'per-night']:
    # How many of the worst baseline-pols to show per spectral window.
    n_biggest_diffs = 1
    biggest_inpaint_diffs = {
        (spw.start, spw.stop): get_biggest_inpaint_differences(spw, n=n_biggest_diffs)
        for spw in spws
    }
if make_plots and inpaint_method in ['simultaneous', 'per-night']:
    # One panel per (spectral window, worst baseline-pol): the flagged
    # (pre-inpainting) mean, the inpainted mean, and each night's data.
    nplots = len(biggest_inpaint_diffs)*n_biggest_diffs
    # squeeze=False keeps `ax` indexable even when there is a single panel.
    fig, ax = plt.subplots(nplots, 1, figsize=(15, 2.5*nplots), layout='constrained', squeeze=False)
    ax = ax[:, 0]
    kk = 0
    complete_flags = cross_stacks[0].flagged_or_inpainted()
    for band, blpol_list in biggest_inpaint_diffs.items():
        band = slice(*band)
        for blpol in blpol_list:
            apidx = cross_stacks[0].antpairs.index(blpol[:2])
            polidx = cross_stacks[0].pols.index(blpol[2])
            worst_idx = blpol[3]

            # Plot the whole spectral window containing the worst channel
            # (including its first channel); fall back to the band itself.
            subband = next(
                (spw for spw in spws if spw.start <= worst_idx < spw.stop), band
            )

            fq = cross_stacks[0].freq_array[subband] / 1e6
            flagged_mean = np.abs(
                np.where(
                    np.all(cross_stacks[0].flags[:, apidx, subband, polidx], axis=0), np.nan, original_data_mean[0][apidx, subband, polidx]
                )
            )
            ax[kk].plot(fq, flagged_mean, lw=3, ls='-', color='k', label='flagged mean')
            inpmean = np.abs(cross_lstavg[0]['data'][apidx, subband, polidx])
            ax[kk].plot(fq, inpmean, lw=3, ls='--', color='red', label='simul. inpaint')
            gotlabel = False
            for i, jd in enumerate(cross_stacks[0].nights):
                flg = complete_flags[i, apidx, subband, polidx]
                if np.all(flg):
                    continue
                d = np.abs(cross_stacks[0].data[i, apidx, subband, polidx])
                ax[kk].plot(fq, d, lw=1, alpha=0.6, color=styles[jd]['color'])
                if np.any(flg):
                    ax[kk].scatter(
                        fq[flg], d[flg], marker='o', edgecolors=styles[jd]['color'], facecolors='none', alpha=0.4,
                        label='flagged datum' if not gotlabel else None
                    )
                    gotlabel = True
            ax[kk].legend(ncols=2)
            ax[kk].set_title(f"({int(blpol[0])}, {int(blpol[1])}, {blpol[2]})")
            # Fit the y-axis to the two means with 20% margins, ignoring NaNs;
            # leave autoscaling alone if there is nothing finite to fit.
            both_means = np.concatenate([flagged_mean, inpmean])
            if np.any(np.isfinite(both_means)):
                plotmin = np.nanmin(both_means)
                plotmax = np.nanmax(both_means)
                rng = plotmax - plotmin
                ax[kk].set_ylim(plotmin - rng/5, plotmax + rng/5)
            kk += 1
    fig.suptitle("Examples of Biggest Differences in Inpainted Solutions")
Write Out Averaged Data¶
So far, `cross_lstavg` holds either a flagged-mode average or a simultaneously-inpainted average (when `inpaint_method` is `'simultaneous'` or `'per-night'`). If the data were instead inpainted beforehand (`inpaint_method == 'pre-inpainted'`), re-average them here in inpainted mode:
if inpaint_method=='pre-inpainted':
    # The data were inpainted upstream: rebuild each LST bin's average in
    # inpainted mode (replacing cross_lstavg), with median/MAD if requested.
    cross_lstavg = [
        lstbin.averaging.reduce_lst_bins(
            lststack=stack,
            inpainted_mode=True,
            get_mad=write_med_mad,
        ) for stack in cross_stacks
    ]
# make it a bit easier to create the outfiles
# Pre-bind the metadata shared by every output file; callers supply only antpairs.
create_empty_uvd = partial(
    lstbin.io.create_empty_uvd,
    pols=stackconf.pols,
    file_list=stackconf.matched_metas,
    history=history,
    start_jd=stackconf.properties['first_jd'],
    freq_min=freq_min,
    freq_max=freq_max,
    lst_branch_cut=stackconf.properties["lst_branch_cut"],
    lsts=stackconf.lst_grid,
)
# Pre-bind the output directory and overwrite policy; callers supply fname and uvd_template.
create_file = partial(
    lstbin.io.create_lstbin_output_file,
    outdir=outdir,
    overwrite=overwrite,
)
if save_lstbin_data:
    # Maps (kind, 'autos' | 'cross') -> the value returned by create_file,
    # which write_data later uses as the filename to write into.
    out_files = {}
    kinds = ["LST", "STD"]
    if write_med_mad:
        kinds += ["MED", "MAD"]
    
    for kind in kinds:
        if blchunk == 0:
            # Save the autos as well.
            # (Only done for baseline chunk 0, presumably so the autos are
            # written by a single chunk -- TODO confirm.)
            fname = get_fname(
                fname_format = fname_format.replace("{blchunk}", "autos"), kind=kind
            )
            auto_uvd_template = create_empty_uvd(antpairs=auto_stacks[0].antpairs)
            out_files[(kind, 'autos')] = create_file(fname=fname, uvd_template=auto_uvd_template)
        fname = get_fname(kind=kind)
        cross_uvd_template = create_empty_uvd(antpairs=cross_stacks[0].antpairs)
        out_files[(kind, 'cross')] = create_file(fname=fname, uvd_template=cross_uvd_template)
def write_data(rdc: dict, stack: lstbin.LSTStack, lstidx: int, pairs: str, template):
    """Write one LST bin's reduced data into the pre-created output files.

    Parameters
    ----------
    rdc
        Reduced data with keys 'data', 'flags', 'nsamples', 'std',
        'days_binned' and, when ``write_med_mad`` is set, 'median' and 'mad'.
    stack
        The LST stack for this bin (only its ``Nbls`` is used).
    lstidx
        Index of this LST bin in ``stackconf.lst_grid``.
    pairs
        ``'autos'`` or ``'cross'``; selects the files in ``out_files``.
    template
        Object whose ``write_uvh5_part`` performs the partial writes.
    """
    # Output rows for this LST bin: baseline k is written to row k*n_lsts + lstidx.
    blt_inds = np.arange(stack.Nbls) * stackconf.n_lsts + lstidx

    # (file kind, data key, nsamples key) for each product, in write order.
    products = [("LST", "data", "nsamples"), ("STD", "std", "days_binned")]
    if write_med_mad:
        products += [("MED", "median", "nsamples"), ("MAD", "mad", "days_binned")]

    for kind, data_key, nsample_key in products:
        outfile = out_files[(kind, pairs)]
        template.write_uvh5_part(
            filename=outfile,
            data_array=rdc[data_key],
            nsample_array=rdc[nsample_key],
            blt_inds=blt_inds,
            flag_array=rdc['flags'],
        )
        print(f"Wrote {outfile}")
if save_lstbin_data:
    # Write each LST bin into its rows of the output files (autos only for blchunk 0,
    # matching where their files were created).
    for lstidx in range(stackconf.n_lsts):
        if blchunk==0:
            write_data(autos_lstavg[lstidx], auto_stacks[lstidx], lstidx, 'autos', template=auto_uvd_template)
        write_data(cross_lstavg[lstidx], cross_stacks[lstidx], lstidx, 'cross', template=cross_uvd_template)
        
Wrote /lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/inpaint/zen.LST.6.28319.003.sum.uvh5
Wrote /lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/inpaint/zen.STD.6.28319.003.sum.uvh5
Wrote /lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/inpaint/zen.LST.6.28319.003.sum.uvh5
Wrote /lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal-inpaint-500ns-lstcal/inpaint/zen.STD.6.28319.003.sum.uvh5
if not get_metrics:
    # The rest of the notebook is simply getting metrics and inspecting them.
    # Print the run metadata, then raise SystemExit (intended to stop the run here).
    print_metadata()
    sys.exit(0)
Distributions of $Z^2$¶
# Statistics built from each LST bin's reduced cross-correlation data.
cross_stats = [
    lstmet.LSTBinStats.from_reduced_data(
        rdc=rdc, antpairs=cross_stacks[0].antpairs, pols=stackconf.pols, reds=reds_with_pols if data_is_redundantly_averaged else None
    ) for rdc in cross_lstavg
]
# Predicted distribution of Z^2; its pdf/cdf/ppf are the reference in the cells below.
zdist_pred = lststat.zsquare()
Simple Histogram¶
def autoplot(zsquare):
    """Plot the PDF and CDF of Z^2 for every LST bin and polarization.

    Left panel: density-normalised histograms of Z^2 on log-spaced bins, with
    ``zdist_pred.pdf`` overlaid. Right panel: empirical CDFs (finite values
    only) with ``zdist_pred.cdf`` overlaid. Colour encodes the polarization;
    same-feed pols (``pol[0] == pol[1]``) are drawn thick; the line style
    encodes the LST bin in both panels.
    """
    # Line style per LST bin (cycled if there are more bins than styles).
    lst_linestyles = ['-', ':', '--', '-.']

    fig, ax = plt.subplots(1, 2, figsize=(14, 5))
    x = np.logspace(-5, 7, 200)
    xc = 10**((np.log10(x[1:]) + np.log10(x[:-1]))/2)
    dndzsq = zdist_pred.pdf(xc)
    for i, zsq in enumerate(zsquare):
        ls = lst_linestyles[i % len(lst_linestyles)]
        for j, pol in enumerate(zsq.pols):
            ax[0].hist(zsq.metrics[..., j].flatten(), bins=x, label=f"LST Bin {i}. Pol '{pol}'", density=True, histtype='step', lw=3 if pol[0]==pol[1] else 1, ls=ls, color=f'C{j}', alpha=0.7)
    ax[0].plot(xc, dndzsq, color='k', ls ='--', label='Predicted')
    ax[0].set_xscale('log')
    ax[0].set_yscale('log')
    ax[0].set_ylim(1e-12, 1e4)
    ax[0].legend()
    ax[0].set_xlabel(r"Log10 $Z^2$")
    ax[0].set_title("PDF of $Z^2$")
    # Plot the CDF
    x = np.logspace(-5, np.log10(max(np.nanmax(zsq.metrics) for zsq in zsquare)), 100)
    for i, zsq in enumerate(zsquare):
        # Same per-LST-bin style as the PDF panel (the loop index must be
        # re-bound here, not reused from the loop above).
        ls = lst_linestyles[i % len(lst_linestyles)]
        for j, pol in enumerate(zsq.pols):
            zsqm = zsq.metrics[..., j]
            zsqm = np.sort(zsqm[np.isfinite(zsqm)])
            # Fraction of finite samples strictly below each x.
            cdf_data = np.searchsorted(zsqm, x, side='left') / zsqm.size
            ax[1].plot(x, cdf_data, lw=3 if pol[0]==pol[1] else 1, ls=ls, color=f'C{j}', alpha=0.7)

    ax[1].plot(x, zdist_pred.cdf(x), color='k', ls='--')
    ax[1].set_xlabel(r"$Z^2$")
    ax[1].set_title("CDF of $Z^2$")
    ax[1].set_xscale('log')
if make_plots:
    # PDF/CDF of Z^2 for every LST bin, compared with the prediction.
    autoplot(zsquare)
Get list of bads¶
def consecutive(data: np.ndarray, stepsize: int=1) -> list[tuple[int, int]]:
    """Split sorted integers into runs, returned as half-open ``(start, stop)`` ranges.

    Adapted from https://stackoverflow.com/a/46606745/1467820.

    Parameters
    ----------
    data
        Sorted 1D array of integers, e.g. channel indices from ``np.nonzero``.
    stepsize
        Difference between successive elements that counts as "consecutive".

    Returns
    -------
    list of (int, int)
        One ``(start, stop)`` per run, where ``stop`` is one past the run's last
        element. Single-element and multi-element runs use the same convention,
        so ``arr[start:stop]`` selects exactly the run (for ``stepsize=1``) and
        ``stop - start`` is its length. Empty input gives an empty list.
    """
    data = np.asarray(data)
    if data.size == 0:
        return []
    runs = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    return [(int(run[0]), int(run[-1]) + 1) for run in runs]
allbad = {}
inpainted_regions = {}
# Z^2 above the 99.85% point of the predicted distribution counts as "bad"
# (one-sided "3-sigma"). Computed once instead of for every spectrum.
zsq_bad_threshold = zdist_pred.ppf(0.9985)
for lstidx, (zuv, stack) in enumerate(zip(zsquare, cross_stacks)):
    inpaint_flags = stack.inpainted()
    for iap, (a, b) in enumerate(zuv.antpairs):
        for ipol, pol in enumerate(zuv.pols):
            for night, zsqn in enumerate(zuv.metrics[:, iap, :, ipol]):
                # Look up this night's JD *before* building the key.
                jdint = zuv.nights[night]
                key = (lstidx, a, b, ipol, jdint)
                # Get contiguous regions of bad data, i.e. outside "3-sigma" or 99.85%
                badfreqs = np.nonzero(zsqn > zsq_bad_threshold)[0]
                if len(badfreqs) > 0:
                    for rng in consecutive(badfreqs):
                        allbad[key+(rng[0], rng[1])] = zsqn[rng[0]:rng[1]]
                if stackconf.inpaint_files is not None:
                    # Get contiguous regions of (pre-)inpainted data. These are
                    # inpainted regions, not high-Z^2 ones, so they must not go
                    # into `allbad` (which is saved as the HIGHZ file).
                    badfreqs = np.nonzero(inpaint_flags[night, iap, :, ipol])[0]
                    if len(badfreqs) > 0:
                        for rng in consecutive(badfreqs):
                            inpainted_regions[key+(rng[0], rng[1])] = zsqn[rng[0]:rng[1]]
                else:
                    # Get contiguous regions of inpainted data (anything that is flagged outside FM)
                    all_chans = np.arange(zsqn.size)
                    for band in inpaint_bands:
                        # Convert to absolute channel indices so that ranges from
                        # different bands cannot collide in the dict keys.
                        band_chans = all_chans[band]
                        badfreqs = band_chans[np.nonzero(stack.flags[night, iap, band, ipol])[0]]
                        if len(badfreqs) > 0:
                            for rng in consecutive(badfreqs):
                                inpainted_regions[key+(rng[0], rng[1])] = zsqn[rng[0]:rng[1]]
if save_metric_data:
    # Write out the "bad" data
    # 'indices' has one row per key of `allbad`, i.e.
    # (lstidx, ant1, ant2, ipol, jd, start, stop); 'zsq' is every stored Z^2
    # segment concatenated in the same key order.
    fname = get_fname(kind="HIGHZ")
    with h5py.File(outdir / fname, 'w') as fl:
        fl['indices'] = np.array(list(allbad.keys()))  # integer array
        fl['zsq'] = (np.concatenate(tuple(allbad.values())) if len(allbad) > 0 else np.array([]))
# For each polarization index, the channel length (stop - start) of every
# high-Z^2 run. Keys of allbad are (lstidx, a, b, ipol, jd, start, stop).
chunk_lengths_all = [
    [badkey[6] - badkey[5] for badkey in allbad if badkey[3] == pol_index]
    for pol_index in range(len(cross_stacks[0].pols))
]
# Report the longest contiguous run of high-Z^2 channels for each polarization.
for i, chunk_lengths in enumerate(chunk_lengths_all):
    if len(chunk_lengths) > 0:
        print(f"Biggest Frequency Chunk With |Z|>3 for Pol {cross_stacks[0].pols[i]}: ", np.max(chunk_lengths))
Biggest Frequency Chunk With |Z|>3 for Pol ee: 140 Biggest Frequency Chunk With |Z|>3 for Pol en: 107 Biggest Frequency Chunk With |Z|>3 for Pol ne: 131 Biggest Frequency Chunk With |Z|>3 for Pol nn: 141
Histogram of freq-chunk size¶
if make_plots:
    for i, chunk_lengths in enumerate(chunk_lengths_all):
        if len(chunk_lengths) > 0:
            # Bin edges up to max+1 so every integer length, including the
            # longest, gets its own bin [k, k+1).
            plt.hist(chunk_lengths, bins=np.arange(np.min(chunk_lengths), np.max(chunk_lengths)+2), label=cross_stacks[0].pols[i], histtype='step')
        else:
            print(f"No |Z| > 3 data found for pol {cross_stacks[0].pols[i]}")
    plt.yscale('log')
    # Runs were selected with zdist_pred.ppf(0.9985), i.e. outside 99.85% of the distribution.
    plt.xlabel("Channel-Chunk Length with $|Z|^2$ outside 99.85% of dist.")
    plt.ylabel("Number of Occurrences")
    plt.legend()
BoxPlots of Z^2 across axis chunks¶
def _set_boxplot_ax_props(nboxes: int, ax):
    """Decorate a Z^2 box-plot axis with reference bands from ``zdist_pred``.

    Shades the predicted 25-75% and 1-99% ranges, marks the predicted mean
    with a dashed line, and sets a log y-axis (from 0.1) with x-limits that
    cover ``nboxes`` boxes centred on 0 .. nboxes-1.
    """
    xspan = [-0.5, nboxes - 0.5]
    for lo_q, hi_q in ((0.25, 0.75), (0.01, 0.99)):
        ax.fill_between(xspan, [zdist_pred.ppf(lo_q)]*2, [zdist_pred.ppf(hi_q)]*2, color='gray', alpha=0.2)

    ax.axhline(zdist_pred.mean(), ls='--', color='C3', lw=1)
    ax.set_ylim(1e-1, None)
    ax.set_xlim(*xspan)

    ax.set_yscale('log')
    ax.set_ylabel(r"$Z^2$")
# Samples to ignore when summarising Z^2: anything flagged or inpainted in the
# stack, plus every non-finite Z^2 value. One array per LST bin.
zsq_flags = [
    lst_stack.flagged_or_inpainted() | ~np.isfinite(lst_zsq.metrics)
    for lst_zsq, lst_stack in zip(zsquare, cross_stacks)
]
def box_plot_all_groups(zscores, stacks):
    """Box-plot Z^2 per band and night, one row per baseline subset.

    Each row corresponds to an entry of ``subsets``. Within a row, every band
    in ``bands_considered`` gets a cluster of boxes, one per night in
    ``data_jd_ints``, coloured and styled by ``styles[night]``. Whiskers span the
    1st-99th percentiles and means are marked with stars. Samples in the global
    ``zsq_flags`` are excluded.

    ``stacks`` is accepted for backwards compatibility but is not used.

    Returns
    -------
    np.ndarray
        1D array of the per-subset axes.
    """
    nrows = len(subsets)
    # squeeze=False keeps axx indexable even with a single subset.
    fig, axx = plt.subplots(nrows, 1, sharex=True, figsize=(12, 3*nrows), layout='constrained', squeeze=False)
    axx = axx[:, 0]
    # The middle box of each cluster carries the band's tick label.
    mid_night = len(data_jd_ints) // 2
    for j, (name, selector) in enumerate(subsets.items()):
        ax = axx[j]
        is_bottom_row = j == nrows - 1

        for i, band in enumerate(bands_considered):
            for n, night in enumerate(data_jd_ints):
                allz = lstmet.get_compressed_zscores(
                    zscores, band=band, nights=night, bl_selectors=selector,
                    flags=zsq_flags,
                )

                tick_label = f"chs {band[0]}-{band[1]}" if (n == mid_night and is_bottom_row) else ""
                bplot = ax.boxplot(
                    allz, positions=[i - 0.3 + 0.05*n],
                    showfliers=False, whis=(1, 99), showmeans=True,
                    tick_labels=[tick_label],
                )

                color = styles[night]['color']
                ls = styles[night]['ls']
                bplot['boxes'][0].set_color(color)
                bplot['boxes'][0].set_linestyle(ls)
                for whisker in bplot['whiskers']:
                    whisker.set_color(color)
                    whisker.set_linestyle(ls)
                for cap in bplot['caps']:
                    cap.set_color(color)

                mean_marker = bplot['means'][0]
                mean_marker.set_marker("*")
                mean_marker.set_markerfacecolor(color)
                mean_marker.set_markeredgecolor(color)
                mean_marker.set_markersize(10)

                if i == 0 and j == 0:
                    # Dummy lines for legend
                    ax.plot([1, 2], [np.nan, np.nan], **styles[night], label=str(night))

        _set_boxplot_ax_props(len(bands_considered), ax)
        ax.set_ylabel(name.replace(" baselines", ""))

    axx[0].legend(ncols=3)

    return axx
if make_plots:
    # Trailing ';' suppresses the notebook's display of the returned axes.
    box_plot_all_groups(zsquare, cross_stacks);
Mean Z^2 Over Different Axes¶
metrics = {}


def reduce_zsquare(axis, **kw):
    """Mean-reduce each LST bin's Z^2 over ``axis`` after down-selection.

    ``kw`` (e.g. ``band=`` or ``bl_selectors=``) is forwarded to
    ``lstmet.downselect_zscores``, together with the matching entry of
    ``zsq_flags``. The reduction uses ``np.ma.mean``. Returns one result per
    entry of ``zsquare``.
    """
    reduced = []
    for lst_zsq, lst_flags in zip(zsquare, zsq_flags):
        selected = lstmet.downselect_zscores(lst_zsq, flags=lst_flags, **kw)
        reduced.append(lstmet.reduce_stack_over_axis(np.ma.mean, selected, axis=axis))
    return reduced
# Per band: mean Z^2 over the band's channels.
metrics['band_reduced_mean'] = {}
for band in bands_considered:
    metrics['band_reduced_mean'][band] = reduce_zsquare(band=band, axis='freqs')
# Per baseline subset: mean Z^2 over the subset's baselines.
metrics['bl_reduced_mean'] = {}
# NOTE(review): allbls is not used in this cell.
allbls = [(a,b,p) for a,b in stackconf.antpairs for p in stackconf.pols]
for j, (name, selector) in enumerate(subsets.items()):
    metrics['bl_reduced_mean'][name] = reduce_zsquare(bl_selectors=selector, axis='bls')
# Mean Z^2 over nights.
metrics['night_reduced_mean'] = reduce_zsquare(axis='nights')
# Per baseline subset: mean over nights and baselines.
metrics['night_and_bl_reduced_mean'] = {}
for j, (name, selector) in enumerate(subsets.items()):    
    metrics['night_and_bl_reduced_mean'][name] = reduce_zsquare(bl_selectors=selector, axis=("nights", "bls"))
# Per band: mean over nights and channels.
metrics['night_and_band_reduced_mean'] = {}
for band in bands_considered:
    metrics['night_and_band_reduced_mean'][band] = reduce_zsquare(band=band, axis=('nights', 'freqs'))
# Per (band, subset): mean over baselines and channels.
metrics['bl_and_band_reduced_mean'] = {}
for j, (name, selector) in enumerate(subsets.items()):
    for band in bands_considered:
        metrics['bl_and_band_reduced_mean'][(band, name)] = reduce_zsquare(band=band, bl_selectors=selector, axis=('bls', 'freqs'))
# Per (band, subset): mean over baselines, channels and nights.
metrics['all_reduced_mean'] = {}
for j, (name, selector) in enumerate(subsets.items()):    
    for band in bands_considered:
        metrics['all_reduced_mean'][(band, name)] = reduce_zsquare(band=band, bl_selectors=selector, axis=('bls', 'freqs', 'nights'))
Plot Totally Reduced¶
# Plot style per baseline subset: a distinct colour index for each subset, with
# the line style changing every 4 subsets (cycling beyond 16 subsets).
subset_styles = {name: {'color': f"C{i%len(subsets)}", 'ls': ['-', '--', ':', '-.'][(i//4) % 4]} for i, name in enumerate(subsets.keys())}
if make_plots:
    done = set()
    for (band, subset_name), means in metrics['all_reduced_mean'].items():
        mid = np.mean(band)
        # Marker encodes band width: 'o' for exactly 200 channels, 'x' for < 1500, '*' otherwise.
        size=0 if band[1]-band[0]==200 else (1 if band[1]-band[0] < 1500 else 2)
        # Note that what we are plotting here is the Z^2/2, whose distribution has a mean
        # of one. Unfortunately the exact distribution is dependent on the number of 
        # unflagged data in each, and instead of tracking that, for such a simple 
        # plot we just approximate the variance as 1/P, where P is the number of 
        # unflagged data in the sample.
        plt.errorbar(
            [mid],
            means[0]/2,
            xerr=[[mid - band[0]]],
            marker='ox*'[size],
            markersize=8,
            **subset_styles[subset_name],
            label=subset_name.replace("baselines", "") if subset_name not in done else None
        )
        done.add(subset_name)
    plt.legend(ncols=2)
    plt.axhline(1)
    # With variance ~1/P the standard deviation of the mean is 1/sqrt(P), so
    # this line marks +3 sigma (P taken from the first LST bin's unflagged count).
    p = np.sum(~zsq_flags[0])
    plt.axhline(1 + 3/np.sqrt(p), color='gray')
    plt.yscale('log')
    plt.ylabel("Mean $|Z|^2/2$ over bls, channels, nights")
    plt.xlabel("Channel")
Plot Reduced over Nights + Bands¶
def make_baseline_zsq_plot(zscores, stack):
    """Scatter each baseline's mean Z^2 in the uv plane, per narrow band and pol.

    One row per band in ``bands_considered`` that is at most 200 channels wide,
    one column per polarization. Colours show
    ``metrics['night_and_band_reduced_mean'][band][0]`` (first LST bin) on a log
    scale from 1 to 1000, using a three-colour map. Baselines are folded into
    the v >= 0 half-plane.

    ``zscores`` is accepted for backwards compatibility but is not used.
    """
    # TODO: need a better cmap to easily see what's "good" and "bad"
    narrow_bands = [band for band in bands_considered if band[1] - band[0] <= 200]
    nrows = len(narrow_bands)
    if nrows == 0:
        return

    # squeeze=False keeps axx 2D even with a single row.
    fig, axx = plt.subplots(nrows, len(stackconf.pols), sharey=True, figsize=(24, 5*nrows), layout='constrained', squeeze=False)

    cmap = mpl.colors.ListedColormap(["C0", "C1", "C3"])

    # Copy before folding so that stack.uvw_array itself is not modified.
    uvws = stack.uvw_array[:stack.Nbls, :2].copy()
    uvws[uvws[:, 1] < 0] *= -1

    for ax, band in zip(axx, narrow_bands):
        mean_zsq = metrics['night_and_band_reduced_mean'][band][0]

        for ipol, pol in enumerate(stackconf.pols):
            cbar = ax[ipol].scatter(uvws[:, 0], uvws[:, 1], c=mean_zsq[:, ipol], norm=mpl.colors.LogNorm(vmin=1, vmax=1000), marker='H', s=60, cmap=cmap)
            ax[ipol].set_title(pol)
            ax[ipol].set_aspect("equal", 'datalim')
            ax[ipol].set_xlim(-200, 200)
            ax[ipol].grid(True)

        ax[0].set_ylabel(str(band))

        plt.colorbar(cbar, ax=ax)
if make_plots:
    # uv-plane view of per-baseline mean Z^2 for the first LST bin's stack.
    make_baseline_zsq_plot(zsquare, cross_stacks[0])
Plot Reduced over Nights and bls¶
def plot_excess_variance_wrt_freq():
    """Plot Z^2/2 against frequency for each baseline subset.

    Uses ``metrics['night_and_bl_reduced_mean']`` (already reduced over nights
    and baselines) and additionally averages over the LST bins with
    ``np.nanmean``.
    """
    freqs_mhz = stackconf.config.datameta.freq_array / 1e6
    for subset_name, per_lst in metrics['night_and_bl_reduced_mean'].items():
        # Average over the LST bins (the list axis).
        lst_mean = np.nanmean(per_lst, axis=0)
        plt.plot(freqs_mhz, lst_mean/2, label=subset_name.replace("baselines", ""), **subset_styles[subset_name])

    plt.xlabel("Freq [MHz]")
    plt.ylabel(r"Mean $Z^2$ across Nights, LSTs and Baselines")
    plt.legend(ncols=2)
    plt.ylim(7e-1, 100)
    plt.yscale('log')
if make_plots:
    # Z^2/2 vs frequency, per baseline subset.
    plot_excess_variance_wrt_freq()
Plot Reduced over Bls¶
def plot_reduced_over_bls(zstack):
    """Show a night-by-frequency image of baseline-averaged Z^2/2 per subset.

    Each image row is one night of ``data_jd_ints``, filled from
    ``metrics['bl_reduced_mean'][subset][0]`` (first LST bin); nights absent from
    ``zstack.nights`` stay NaN. Panels are laid out three per row, with a single
    shared colourbar.

    NOTE(review): the images hold Z^2/2 (see the ``/2`` below) while the
    colourbar label says "Mean $Z^2$".
    """
    images = {}
    for subset, zsqs in metrics['bl_reduced_mean'].items():
        images[subset] = np.ones((len(data_jd_ints), len(zstack.freq_array)), dtype=float) * np.nan
        for ijd, jd in enumerate(data_jd_ints):
            if jd not in zstack.nights:
                continue
            # Position of this JD within zstack's own night axis.
            jdidx = zstack.nights.tolist().index(jd)
            images[subset][ijd] = zsqs[0][jdidx]/2
    nrows = int(np.ceil(len(subsets)/3))
    fig, ax = plt.subplots(nrows, 3, sharex=True, sharey=True, layout='constrained', figsize=(14, 3*nrows))
    # Three discrete colours over the log-normalised range [1, 1000].
    cmap = mpl.colors.ListedColormap(["C0", "C1", "C3"])
    for i, (key, img) in enumerate(images.items()):
        axx = ax.flatten()[i]
        plt.sca(axx)
        cbar = plt.imshow(
            img, norm=mpl.colors.LogNorm( vmin=1, vmax=1000),
            origin='lower',
            extent=(
                zstack.freq_array.min()/1e6, 
                zstack.freq_array.max()/1e6,
                0,
                len(data_jd_ints)
            ),
            cmap=cmap, aspect='auto',
            interpolation='none',
        )
        # One tick centred on each night's row, labelled with its JD.
        axx.yaxis.set_ticks(np.arange(img.shape[0]) +0.5)
        axx.yaxis.set_ticklabels(data_jd_ints)
        axx.set_title(key.replace("baselines", ""), pad=-3)
        if i < 3:
            # Panels in the first row show their frequency labels on top.
            axx.tick_params('x', labeltop=True, labelbottom=False, top=True)
    # Hide any unused panels in the last row.
    for j in range(i+1, ax.size):
        ax.flatten()[j].axis('off')
    # Shared colourbar, built from the last image drawn (all share the same norm).
    cbar = plt.colorbar(cbar, ax = ax)
    cbar.set_label(r"Mean $Z^2$ over bl subset")
if make_plots:
    # Images for the first LST bin.
    plot_reduced_over_bls(zsquare[0])
Plot Selection of the Worst Visibilities¶
def plot_visibilities_per_type(
    lstbin_blpols: list[tuple[int, tuple[int, int, str]]],
    stacks: list[UVData],
    stats: list[lstmet.LSTBinStats],
    auto_stats: list[lstmet.LSTBinStats],
    comments: list[str],
    zscores: list[lstbin.binning.LSTStack],
    freq_range: None | tuple[float, float] | list[tuple[int, int]] = None,
    label=None,
    yrange=None,
    alpha=0.5,
):
    """Yield one diagnostic figure per requested (LST bin, baseline-pol).

    Each figure is a 4x2 grid. The left column shows magnitude, magnitude minus
    the LST-binned mean, and Z^2. The right column shows the real part, the
    real z-score, the imaginary part and the imaginary z-score. Every night is
    drawn with its ``styles`` entry and inpainted channels are shaded. The
    LST-binned mean is drawn in thick black.

    Parameters
    ----------
    lstbin_blpols
        ``(lstidx, (ant1, ant2, pol))`` for each figure.
    stacks, stats, auto_stats, zscores
        Per-LST-bin stacks, reduced statistics, auto statistics and Z^2 objects,
        indexed by ``lstidx``.
    comments
        Text for the top-right panel of each figure (zipped with
        ``lstbin_blpols``).
    freq_range
        One of three forms:

        - ``None`` for all channels.
        - A ``(fmin, fmax)`` tuple in Hz, applied to every figure.
        - A list of ``(start, stop)`` channel ranges, one per figure. Each is
          padded by 100 channels on both sides.
    label
        Unused; kept for backwards compatibility.
    yrange
        Optional y-limits for the magnitude panel.
    alpha
        Alpha of the per-night legend handles.

    Yields
    ------
    matplotlib.figure.Figure
        Each figure is closed when the consumer requests the next one.
    """
    lststyle = dict(color='k', lw=3, zorder=-1)
    meta = stackconf.config.datameta
    nfreqs = meta.Nfreqs

    # Get a mask that says which channels are *simultaneously* inpainted.
    simul_inpmask = np.zeros(nfreqs, dtype=bool)
    for band in inpaint_bands:
        simul_inpmask[band] = True
    if isinstance(freq_range, tuple):
        # (fmin, fmax) in the same units as meta.freq_array.
        mask = (meta.freq_array >= freq_range[0]) & (meta.freq_array < freq_range[1])
        freqs = meta.freq_array[mask]/1e6
    else:
        mask = slice(None)
        freqs = meta.freq_array/1e6

    # Legend handles: one line per night, styled as in the plots.
    handles = [
        mpl.lines.Line2D([0], [0], label=str(jdint), alpha=alpha, **style)
        for jdint, style in styles.items()
    ]

    for i, (comment, (lstidx, blpol)) in enumerate(zip(comments, lstbin_blpols)):
        if isinstance(freq_range, list):
            this_range = freq_range[i]

            # pad the range a bit (clipped to the available channels)
            this_range = (max(this_range[0] - 100, 0), min(this_range[1] + 100, nfreqs))
            mask = slice(this_range[0], this_range[1])
            freqs = meta.freq_array[mask]/1e6

        stack = stacks[lstidx]
        zscore = zscores[lstidx]

        rawd = stack.get_data(blpol)[:, mask]
        rawf = stack.get_flags(blpol)[:, mask]
        rawn = stack.get_nsamples(blpol)[:, mask]

        # If no nsamples are negative, treat flagged channels inside the
        # simultaneously-inpainted bands as inpainted; otherwise negative
        # nsamples mark the inpainted samples.
        if np.all(rawn >= 0):
            inp = rawf & simul_inpmask[mask]
        else:
            inp = rawn < 0

        lstf = stats[lstidx].flags[blpol][mask]
        lstd = stats[lstidx].mean[blpol][mask]

        lstmed = lstd  # actually mean

        iap = zscore.antpairs.index(blpol[:2])
        ipol = zscore.pols.index(blpol[2])

        zsq = zscore.metrics[:, iap, mask, ipol]

        if np.all(lstf):
            print("ALL FLAGGED")
            continue

        fig, ax = plt.subplots(
            4, 2,
            sharex=True, figsize=(15, 8),
            constrained_layout=True, gridspec_kw={'height_ratios': (2,1,2,1)}
        )

        mag = np.where(rawf, np.nan, np.abs(rawd))
        rl = np.where(rawf, np.nan, rawd.real)
        im = np.where(rawf, np.nan, rawd.imag)

        maglstbin = np.where(lstf, np.nan, np.abs(lstd))
        rllstbin = np.where(lstf, np.nan, lstd.real)
        imlstbin = np.where(lstf, np.nan, lstd.imag)

        rllstbin_med = np.where(lstf, np.nan, lstmed.real)
        imlstbin_med = np.where(lstf, np.nan, lstmed.imag)

        # Predicted per-night std of the real/imag parts (variance split evenly).
        pred_std = np.sqrt(lstmet.get_nightly_predicted_variance(blpol, stack=stack, auto_stats = auto_stats[lstidx]) / 2)[:, mask]

        ax[0, 0].plot(freqs, maglstbin, **lststyle)
        ax[0, 1].plot(freqs, rllstbin, **lststyle)
        ax[2, 1].plot(freqs, imlstbin, **lststyle)

        for jdidx, jdint in enumerate(stack.nights):
            style = styles[jdint]
            if np.all(rawf[jdidx]):
                continue
            thisinp = inp[jdidx]
            if np.any(thisinp):
                inp_ranges = consecutive(np.nonzero(thisinp)[0])
            else:
                inp_ranges = []

            # Amplitude and Phase
            ax[0, 0].plot(freqs, mag[jdidx], **style)
            for rng in inp_ranges:
                ax[0, 0].fill_between(freqs[rng[0]:rng[1]], mag[jdidx, rng[0]:rng[1]], maglstbin[rng[0]:rng[1]], color=style['color'], alpha=0.2)

            ax[1, 0].plot(freqs, mag[jdidx] - maglstbin, **style)
            for rng in inp_ranges:
                ax[1, 0].fill_between(freqs[rng[0]:rng[1]], mag[jdidx, rng[0]:rng[1]] - maglstbin[rng[0]:rng[1]], 0, color=style['color'], alpha=0.2)

            ax[2, 0].plot(freqs, zsq[jdidx], **style)
            for rng in inp_ranges:
                ax[2, 0].fill_between(freqs[rng[0]:rng[1]], zsq[jdidx, rng[0]:rng[1]], 0, color=style['color'], alpha=0.2)
            # Real / Imag
            ax[0, 1].plot(freqs, rl[jdidx], **style)
            for rng in inp_ranges:
                ax[0, 1].fill_between(freqs[rng[0]:rng[1]], rl[jdidx, rng[0]:rng[1]], rllstbin[rng[0]:rng[1]], color=style['color'], alpha=0.2)

            rldiff = (rl[jdidx] - rllstbin_med)/pred_std[jdidx]
            ax[1, 1].plot(freqs, rldiff, **style)
            for rng in inp_ranges:
                ax[1,1].fill_between(freqs[rng[0]:rng[1]], rldiff[rng[0]:rng[1]], 0, color=style['color'], alpha=0.2)

            ax[2, 1].plot(freqs, im[jdidx], **style)
            for rng in inp_ranges:
                ax[2, 1].fill_between(freqs[rng[0]:rng[1]], im[jdidx, rng[0]:rng[1]], imlstbin[rng[0]:rng[1]], color=style['color'], alpha=0.2)

            imdiff = (im[jdidx] - imlstbin_med)/pred_std[jdidx]
            ax[3, 1].plot(freqs, imdiff, **style)
            for rng in inp_ranges:
                ax[3,1].fill_between(freqs[rng[0]:rng[1]], imdiff[rng[0]:rng[1]], 0, color=style['color'], alpha=0.2)
        if yrange:
            ax[0, 0].set_ylim(yrange)
        ax[1,1].axhline(4, color='gray', ls='--')
        ax[1,1].axhline(-4, color='gray', ls='--')
        ax[3,1].axhline(4, color='gray', ls='--')
        ax[3,1].axhline(-4, color='gray', ls='--')

        bl_coords = stackconf.config.datameta.antpos_enu[blpol[0]] - stackconf.config.datameta.antpos_enu[blpol[1]]

        fig.suptitle(
            f"Baseline: {blpol} [{bl_coords[0]:.1f}-EW, {bl_coords[1]:.1f}-NS]. "
            f"LST = {stackconf.lst_grid[0]*12/np.pi:.5f} hr."
        )
        ax[-1, 0].set_xlabel("Frequency [MHz]")
        ax[-1, 1].set_xlabel("Frequency [MHz]")

        ax[0, 0].set_ylabel("Magnitude")
        ax[0, 1].set_ylabel("Real Part")

        ax[1, 0].set_ylabel("Magnitude Diff")
        ax[1, 1].set_ylabel("Real Z-score")
        ax[1, 1].set_ylim(-7, 7)

        ax[2, 0].set_ylabel(r"$Z^2$")
        ax[2, 0].set_yscale('log')
        ax[2, 0].set_ylim(1e-1,)

        ax[2, 1].set_ylabel("Imag Part")

        ax[3, 1].set_ylabel("Imag Z-score")
        ax[3, 1].set_ylim(-7, 7)
        ax[0, 0].legend(handles=handles, ncols=5)
        ax[0,1].text(0.95, 0.95, comment, transform=ax[0,1].transAxes, ha='right', va='top')
        for axx in ax.flatten():
            # Guide lines every 200 channels across the full band.
            for line in range(0, nfreqs, 200):
                axx.axvline(meta.freq_array[line]/1e6, color='gray', alpha=0.4)
            axx.set_xlim(freqs[0], freqs[-1])
        yield fig
        plt.close(fig)
def get_worst_mean_over_each_band(zscores, n=1, autopols_only=True):
    """Report, for every band, the ``n`` entries with the largest band-mean Z^2.

    Reads the notebook globals ``metrics['band_reduced_mean']`` (a dict mapping
    each band to a per-LST-bin list of arrays indexed ``[night, antpair, pol]``),
    ``data_jd_ints`` and ``stackconf.pols``.

    Parameters
    ----------
    zscores
        Per-LST-bin Z^2 containers. Each must expose ``.times`` (the nights it
        contains). ``zscores[0].antpairs`` is used as the antpair list for every bin.
    n
        Number of worst entries to report per band. Values <= 0 report nothing.
    autopols_only
        If True, only consider polarizations whose characters are all identical
        (``len(set(pol)) == 1``). Otherwise consider every polarization.

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, mean_zsq, label, band)`` tuples, worst-first within each band.
    """
    bad_fellas = {}

    if autopols_only:
        polidx = [i for i, p in enumerate(stackconf.pols) if len(set(p)) == 1]
    else:
        polidx = list(range(len(stackconf.pols)))
    polidx_set = set(polidx)
    antpairs = zscores[0].antpairs

    # Different LST bins can contain different subsets of nights. Map each bin's
    # nights onto rows of data_jd_ints so all bins share one common night axis.
    nights_per_lst = [
        [data_jd_ints.index(jd) for jd in zs.times.astype(int)] for zs in zscores
    ]
    shape = (len(zscores), len(data_jd_ints), len(antpairs), len(polidx))

    # Flat index -> (lst, jd, antpair + pol). The loop order matches C-order
    # flattening of an array with the shape above.
    lst_night_bl_pols = [
        (lst, jd, bl + (pol,))
        for lst in range(len(zscores))
        for jd in data_jd_ints
        for bl in antpairs
        for j, pol in enumerate(stackconf.pols) if j in polidx_set
    ]

    for band, zsqs in metrics['band_reduced_mean'].items():
        newmeans = np.full(shape, np.nan)
        for ilst, nights in enumerate(nights_per_lst):
            newmeans[ilst, nights] = zsqs[ilst][..., polidx]

        flat = newmeans.ravel()
        # NaN marks (night, bl, pol) combinations absent from a bin. They carry no
        # Z^2 information, so they must never be reported as "worst".
        valid = np.flatnonzero(~np.isnan(flat))
        nn = min(n, valid.size)
        if nn <= 0:
            continue

        candidates = valid[np.argpartition(flat[valid], -nn)[-nn:]]
        # Sort indices worst-first and read the values through the sorted indices,
        # so each index stays paired with its own Z^2.
        worst_idx = candidates[np.argsort(-flat[candidates])]

        for idx in worst_idx:
            lst, jd, bl = lst_night_bl_pols[idx]
            bad_fellas.setdefault((lst, bl), []).append(
                (jd, flat[idx], fr"Worst $Z^2$ in band {band[0]}-{band[1]}", band)
            )
    return bad_fellas
def get_worst_mean_for_continuously_bad_stuff(zscores, n=1, autopols_only=True):
    """Report the worst contiguous "bad" channel chunks, grouped by chunk width.

    Reads the notebook global ``allbad``, keyed by
    ``(lst, ant1, ant2, pol_index, jdint, low, high)``. Its values are passed to
    ``np.nanmean`` and are presumably the Z^2 values over that chunk (TODO
    confirm). Chunks are bucketed by width ``high - low`` and, within each
    bucket, the ``n`` chunks with the largest mean are reported.

    Parameters
    ----------
    zscores
        Unused. Kept for signature parity with the other ``get_*`` finders.
    n
        Number of worst chunks to report per width bucket. Values <= 0 report
        nothing.
    autopols_only
        If True, only consider polarizations whose characters are all identical
        (``len(set(pol)) == 1``).

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, mean_zsq, label, (low, high))`` tuples.
    """
    bad_fellas = {}

    # Width buckets use half-open [lo, hi) ranges. Width-1 chunks are skipped
    # below, so the (1, 2) bucket is always empty.
    chsizes = [(1, 2), (2, 10), (10, 20), (20, 50), (50, 100), (100, 1536)]
    sized = {ch: {} for ch in chsizes}
    for k, v in allbad.items():
        s = k[-1] - k[-2]  # size of chunk
        if s == 1:
            continue
        for ch in chsizes:
            if ch[0] <= s < ch[1]:
                sized[ch][k] = v

    if autopols_only:
        polidx = {i for i, p in enumerate(stackconf.pols) if len(set(p)) == 1}
    else:
        polidx = set(range(len(stackconf.pols)))

    for chsize, thesebads in sized.items():
        # Build keys and means together so they stay aligned. Drop chunks whose
        # mean is undefined (all-NaN): they are not evidence of badness.
        keys, means = [], []
        for k, v in thesebads.items():
            if k[3] not in polidx:
                continue
            m = np.nanmean(v)
            if np.isnan(m):
                continue
            keys.append(k)
            means.append(m)

        nn = min(n, len(means))
        if nn <= 0:
            # Also guards np.argpartition against an empty array.
            continue

        meanz = np.asarray(means)
        candidates = np.argpartition(meanz, -nn)[-nn:]
        # Worst-first, with each index paired to its own mean.
        worst_idx = candidates[np.argsort(-meanz[candidates])]

        label = fr"Worst $Z^2$ over {chsize[0]}-{chsize[1]} channels"
        for idx in worst_idx:
            lst, a, b, pol, jdint, low, high = keys[idx]
            bl = (a, b, stackconf.pols[pol])
            bad_fellas.setdefault((lst, bl), []).append(
                (int(jdint), meanz[idx], label, (low, high))
            )
    return bad_fellas
def get_worst_continuous_bad_zscore(zscores, n=1, autopols_only=True):
    """Report entries whose band-mean Z^2 is high in *every* narrow band.

    For each (LST bin, night, antpair, pol), the band-mean Z^2 is collected from
    ``metrics['band_reduced_mean']`` over all bands in the global
    ``bands_considered`` spanning at most 200 channels. The NaN-aware minimum
    across those bands is then taken. A large minimum means the entry is bad in
    all of those bands at once. The ``n`` largest minima are reported.

    Parameters
    ----------
    zscores
        Per-LST-bin Z^2 containers. Each must expose ``.times``.
        ``zscores[0].antpairs`` is used as the antpair list for every bin.
    n
        Number of worst entries to report. Values <= 0 report nothing.
    autopols_only
        If True, only consider polarizations whose characters are all identical
        (``len(set(pol)) == 1``).

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, min_zsq, label, (0, 1536))`` tuples.
    """
    bad_fellas = {}

    smallbands = [b for b in bands_considered if b[1] - b[0] <= 200]
    if not smallbands:
        # np.nanmin over an empty band axis would raise; nothing to rank.
        return bad_fellas

    if autopols_only:
        polidx = [i for i, p in enumerate(stackconf.pols) if len(set(p)) == 1]
    else:
        polidx = list(range(len(stackconf.pols)))
    polidx_set = set(polidx)
    antpairs = zscores[0].antpairs

    # Align every LST bin's nights onto rows of data_jd_ints, since bins can hold
    # different subsets of nights.
    nights_per_lst = [
        [data_jd_ints.index(jd) for jd in zs.times.astype(int)] for zs in zscores
    ]

    newmeans = np.full(
        (len(smallbands), len(zscores), len(data_jd_ints), len(antpairs), len(polidx)),
        np.nan,
    )
    for i, band in enumerate(smallbands):
        zsqs = metrics['band_reduced_mean'][band]
        for ilst, nights in enumerate(nights_per_lst):
            newmeans[i, ilst, nights] = zsqs[ilst][..., polidx]

    # Flat index -> (lst, jd, antpair + pol), matching C-order flattening after the
    # band axis has been reduced away.
    lst_night_bl_pols = [
        (lst, jd, bl + (pol,))
        for lst in range(len(zscores))
        for jd in data_jd_ints
        for bl in antpairs
        for j, pol in enumerate(stackconf.pols) if j in polidx_set
    ]

    # All-NaN slices stay NaN (numpy emits a RuntimeWarning) and are excluded below.
    zsq = np.nanmin(newmeans, axis=0).ravel()
    valid = np.flatnonzero(~np.isnan(zsq))
    nn = min(n, valid.size)
    if nn <= 0:
        return bad_fellas

    candidates = valid[np.argpartition(zsq[valid], -nn)[-nn:]]
    # Worst-first, with each index paired to its own value.
    worst_idx = candidates[np.argsort(-zsq[candidates])]
    for idx in worst_idx:
        lst, jd, bl = lst_night_bl_pols[idx]
        bad_fellas.setdefault((lst, bl), []).append(
            (jd, zsq[idx], fr"Worst min($Z^2$) over entire band", (0, 1536))
        )
    return bad_fellas
def get_bad_inpaints(zscores, n=1, autopols_only=True):
    """Report inpainted channel chunks with the largest mean Z^2, grouped by width.

    Reads the notebook global ``inpainted_regions``, keyed by
    ``(lst, ant1, ant2, pol_index, jdint, low, high)``. For each key, the mean is
    taken over ``zscores[lst].metrics[night, antpair, low:high, pol]``. Chunks are
    bucketed by width ``high - low``, and the ``n`` worst per bucket are reported.

    Parameters
    ----------
    zscores
        Per-LST-bin Z^2 containers exposing ``.nights``, ``.antpairs`` and
        ``.metrics`` (indexed ``[night, antpair, channel, pol]`` as used here).
    n
        Number of worst chunks to report per width bucket. Values <= 0 report
        nothing.
    autopols_only
        If True, only consider polarizations whose characters are all identical
        (``len(set(pol)) == 1``).

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, mean_zsq, label, (low, high))`` tuples.
    """
    bad_fellas = {}

    nights = [zsq.nights.tolist() for zsq in zscores]
    if autopols_only:
        polidx = {i for i, p in enumerate(stackconf.pols) if len(set(p)) == 1}
    else:
        polidx = set(range(len(stackconf.pols)))

    # Half-open [lo, hi) width buckets. Width-1 chunks are ignored.
    chsizes = [(2, 5), (5, 10), (10, 20)]
    sized = {ch: {} for ch in chsizes}
    for k, v in inpainted_regions.items():
        s = k[-1] - k[-2]  # size of chunk
        if s == 1:
            continue
        for ch in chsizes:
            if ch[0] <= s < ch[1]:
                sized[ch][k] = v

    for chsize, bads in sized.items():
        # Build keys and means together so they stay aligned. Skip all-NaN chunks.
        keys, means = [], []
        for key in bads:
            lst, a, b, pol, jdint, low, high = key
            if pol not in polidx:
                continue
            zs = zscores[lst]
            m = np.nanmean(
                zs.metrics[nights[lst].index(jdint), zs.antpairs.index((a, b)), low:high, pol]
            )
            if np.isnan(m):
                continue
            keys.append(key)
            means.append(m)

        nn = min(n, len(means))
        if nn <= 0:
            # An empty bucket would otherwise make np.argpartition raise.
            continue

        meanz = np.asarray(means)
        candidates = np.argpartition(meanz, -nn)[-nn:]
        # Worst-first, with each index paired to its own mean.
        worst_idx = candidates[np.argsort(-meanz[candidates])]

        label = fr"Worst inpainted $Z^2$ for {chsize[0]}-{chsize[1]} chans"
        for idx in worst_idx:
            lst, a, b, pol, jdint, low, high = keys[idx]
            bl = (a, b, stackconf.pols[pol])
            bad_fellas.setdefault((lst, bl), []).append(
                (int(jdint), meanz[idx], label, (low, high))
            )
    return bad_fellas
if make_plots:
    # Run each "worst offender" finder over the same Z^2 data. Each returns a dict
    # keyed by (lst_index, (ant1, ant2, pol)) whose values are lists of
    # (jd, z, label, range) findings.
    worst_mean_over_each_band = get_worst_mean_over_each_band(
        zscores=zsquare, n=plot_n_worst
    )
    worst_mean_for_continously_bad = get_worst_mean_for_continuously_bad_stuff(
        zscores=zsquare, n=plot_n_worst
    )
    worst_minimum_zscores_over_bands = get_worst_continuous_bad_zscore(
        zscores=zsquare, n=plot_n_worst
    )
    worst_inpainted_regions = get_bad_inpaints(
        zscores=zsquare, n=plot_n_worst
    )
if make_plots:
    # Combine every finder's results into one dict. When several finders flag the
    # same (lst, baseline-pol) key, their findings are concatenated in the order
    # the finders appear in this tuple.
    badstuff = {}
    for dct in (
        worst_mean_over_each_band,
        worst_mean_for_continously_bad,
        worst_minimum_zscores_over_bands,
        worst_inpainted_regions,
    ):
        for k, v in dct.items():
            badstuff.setdefault(k, []).extend(v)
if make_plots:
    # Per key, concatenate the range tuple attached to every finding, then take
    # the overall (min, max) so the plot covers all of them.
    freq_ranges = [
        (min(edges), max(edges))
        for edges in (
            sum((finding[-1] for finding in findings), start=())
            for findings in badstuff.values()
        )
    ]
    for fig in plot_visibilities_per_type(
        lstbin_blpols=list(badstuff),
        stacks=cross_stacks,
        stats=cross_stats,
        # One line of text per finding, formatted as "<label>: <jd>".
        comments=[
            "\n".join(f"{finding[-2]}: {finding[0]}" for finding in findings)
            for findings in badstuff.values()
        ],
        freq_range=freq_ranges,
        alpha=0.5,
        zscores=zsquare,
        auto_stats=auto_stats,
    ):
        plt.show()