Single Night LST-Binned Calibration Notebook¶

by Tyler Cox and Josh Dillon, last updated on April 14th, 2026

This notebook performs LST-binned calibration (or LST-cal) on whole-JD, single-baseline, all polarization files. Most parameters are controlled by a toml config file, such as this one. In addition to single-baseline files, this notebook also requires UVFlag-compatible where_inpainted files which tell us where inpainting was previously done.

To keep the total memory footprint of the notebook reasonable, the full list of baselines are first downselected to a subset of the Nbls redundant baseline types that have the largest number of nsamples, and then compute the redundant-calibration degenerate parameters (namely, the per-frequency/time amplitude, tip-tilt, and cross-polarized phase degeneracies) that bring a single night into better alignment with the LST-average. The calibration parameters are then smoothed in time and frequency with DPSS basis functions given a user specified time and frequency smoothing scale.

Here's a set of links to skip to particular figures and tables:

• Figure 1: Amplitude Parameters Before/After Smoothing¶

• Figure 2: Tip/Tilt Parameters Before/After Smoothing¶

• Figure 3: Cross-Polarized Phase Before/After Smoothing¶

• Figure 4: Visibility/LST-Averaged Variance Across Baseline Before/After LST-Cal¶

In [1]:
import time
tstart = time.time()
!hostname
!date
herapost006
Tue May 12 04:57:26 MDT 2026
In [2]:
import jax
jax.config.update('jax_platform_name', 'cpu') # Force jax to use CPU if GPU available

import re
import os
import yaml
import glob
import toml
import numpy as np
import pylab as plt
from copy import deepcopy
from functools import reduce
from tqdm.notebook import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


from pyuvdata import UVData
from hera_cal.lst_stack import LSTConfig
from hera_cal import lst_stack, io, flag_utils, abscal, datacontainer, redcal, utils, smooth_cal, red_groups
from hera_qm.time_series_metrics import true_stretches
from hera_filters.dspec import fourier_filter, dpss_operator
from hera_cal.lst_stack.calibration import _expand_degeneracies_to_ant_gains
from hera_cal.lst_stack.config import LSTBinConfiguratorSingleBaseline, make_lst_grid
from hera_cal.lst_stack.binning import SingleBaselineStacker, _get_freqs_chans, adjust_lst_bin_edges, _allocate_dfn, get_lst_bins

from hera_qm.metrics_io import read_a_priori_ant_flags

import warnings
warnings.filterwarnings("ignore", module="hera_cal")
%load_ext line_profiler

import jax
jax.config.update('jax_platform_name', 'cpu') # Force jax to use CPU if GPU available
In [3]:
toml_file = os.environ.get(
    'TOML_FILE', '/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/single_bl_lst_stack.toml'
)
print(f'toml_file = "{toml_file}"')

jd_here = os.environ.get('JD', None)
if jd_here is not None:
    jd_here = int(jd_here)
print(f'jd_here = {jd_here}')
toml_file = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml"
jd_here = 2459858
In [4]:
# get options from toml file, print them out, and update globals
toml_options = toml.load(toml_file)

print(f"Now setting the following global variables from {toml_file}:\n")

globals().update({'lst_branch_cut': toml_options['FILE_CFG']['lst_branch_cut']})
print(f"lst_branch_cut = {lst_branch_cut}")

globals().update({'where_inpainted_file_rules': toml_options['FILE_CFG']['where_inpainted_file_rules']})
print(f"where_inpainted_file_rules = {where_inpainted_file_rules}")

# this is used for an initial stacking of a handful of baselines, which are then used for LSTCal
toml_options['LST_STACK_OPTS']['FNAME_FORMAT'] = toml_options['LST_STACK_OPTS']['FNAME_FORMAT'].replace('.sum.uvh5', '.preliminary.sum.uvh5')

for key, val in toml_options['LSTCAL_OPTS'].items():
    if isinstance(val, str):
        print(f'{key} = "{val}"')
    else:
        print(f'{key} = {val}')
globals().update(toml_options['LSTCAL_OPTS'])
        
for key, val in toml_options['LST_STACK_OPTS'].items():
    if isinstance(val, str):
        print(f'{key} = "{val}"')
    else:
        print(f'{key} = {val}')
globals().update(toml_options['LST_STACK_OPTS'])
Now setting the following global variables from /lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml:

lst_branch_cut = 5.2
where_inpainted_file_rules = [['.reinpainted.uvh5', '.where_reinpainted.h5']]
NBLS_FOR_LSTCAL = 30
RUN_AMPLITUDE_CAL = True
RUN_TIP_TILT_PHASE_CAL = True
RUN_CROSS_POL_PHASE_CAL = True
INCLUDE_AUTOS = False
FREQ_SMOOTHING_SCALE = 30.0
TIME_SMOOTHING_SCALE = 2500
BLACKLIST_TIMESCALE_FACTOR = 10
BLACKLIST_RELATIVE_ERROR_THRESH = 0.1
BLACKLIST_NITER = 5
WHERE_INPAINTED_WGTS = 0.0001
LSTCAL_FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.{night}.lstcal.hdf5"
MAX_BL_LENGTH = 100
MIN_AVG_REDUNDANCY = 0.0
OUTDIR = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights"
FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.LST.baseline.{bl_str}.preliminary.sum.uvh5"
USE_LSTCAL_GAINS = True
FM_LOW_FREQ = 87.5
FM_HIGH_FREQ = 108.0
EIGENVAL_CUTOFF = 1e-12
INPAINT_DELAY = 1000
AUTO_INPAINT_DELAY = 100
REINPAINT_AUTOS = False
INPAINT_WIDTH_FACTOR = 0.5
INPAINT_ZERO_DIST_WEIGHT = 0.01
MIN_NIGHTS_PER_BIN = 3
MAX_CHANNEL_NIGHTLY_FLAG_FRAC = 0.25
MOD_Z_TO_REINPAINT = 5
AUTO_FR_SPECTRUM_FILE = "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5"
GAUSS_FIT_BUFFER_CUT = 1e-05
CG_TOL = 1e-06
CG_ITER_LIM = 500
FR0_FILTER = True
FR0_HALFWIDTH = 0.01
FR0_FILT_AUTOS = False
In [5]:
# Baseline configurator
configurator = lst_stack.config.LSTBinConfiguratorSingleBaseline.from_toml(toml_file)
auto_baseline_string = [s for s in configurator.bl_to_file_map if (p := s.split('_'))[0] == p[1]][0]

# Get metadata for LST-stacking
hd = io.HERAData(
    configurator.bl_to_file_map[auto_baseline_string][-1]
)
df = np.median(np.diff(hd.freqs))
dlst = np.median(np.diff(hd.lsts))
lst_grid = lst_stack.config.make_lst_grid(dlst, begin_lst=0, lst_width=(2 * np.pi)) # _fix_dlst function making the grid the wrong size
lst_bin_edges =  np.concatenate([lst_grid - dlst / 2, (lst_grid[-1] + dlst / 2)[None]])

# Julian dates for LST-calibration
jds = [int(night) for night in configurator.nights]
filepath = toml_options['FILE_CFG']['datafiles']['datadir']
aposteriori_yamls = {jd: filepath + f'/{jd}/{jd}_aposteriori_flags.yaml' for jd in jds}

1. Load Single Night Data and Rephase to Correct LST-bins¶

In [6]:
# Get baseline keys from the files that have been previously saved by the stacking notebook
single_bl_files = glob.glob(
    os.path.join(OUTDIR, FNAME_FORMAT.format(bl_str="*"))
)

pattern = '^' + re.escape(os.path.join(OUTDIR, FNAME_FORMAT)).replace(r'\{bl_str\}', r'(?P<a>\d+)_(?P<b>\d+)') + '$'
rx = re.compile(pattern)

baselines = []
baseline_strings = []
for file in single_bl_files:
    match = rx.match(file)
    i, j = map(int, match.groups())
    baselines.append((i, j))
    baseline_strings.append("{}_{}".format(i, j))
In [7]:
RUN_CALIBRATION = jd_here is not None

if RUN_CALIBRATION:
    # Pre-compute rounding factor once — only depends on dlst, not per-baseline
    lst_grid_rounding_factor = np.abs(np.floor(np.log10(dlst) - 2)).astype(int)

    single_day_config = deepcopy(configurator)
    single_day_config.nights = [str(jd_here)]
    single_day_config.bl_to_file_map = {
        bl_str: [file for file in configurator.bl_to_file_map[bl_str] if str(jd_here) in file]
        for bl_str in configurator.bl_to_file_map
    }

    def process_baseline(bl_string):
        crosses = SingleBaselineStacker.from_configurator(
            single_day_config,
            bl_string,
            lst_bin_edges,
            lst_branch_cut=lst_branch_cut,
            where_inpainted_file_rules=where_inpainted_file_rules
        )

        antpair = crosses.hd.antpairs[0]
        bl_to_load = os.path.join(OUTDIR, FNAME_FORMAT.format(bl_str=bl_string))
        polarizations = crosses.hd.pols if RUN_CROSS_POL_PHASE_CAL else crosses.hd.pols[:2]

        hd = io.HERAData(bl_to_load)
        single_bl_stacked_data, lst_avg_flags, _ = hd.read(polarizations=polarizations)

        lst_grid = hd.lsts.copy()
        lst_grid[lst_grid[0] > lst_grid] += 2 * np.pi
        indices = np.searchsorted(
            np.round(lst_grid, lst_grid_rounding_factor),
            np.round(crosses.bin_lst, lst_grid_rounding_factor)
        )

        # Pre-compute once outside the pol loop — same for all polarizations
        times_per_bin = [len(t) for t in crosses.times_in_bins]

        local_wgts = {}
        local_data_for_cal = {}
        local_model = {}
        local_where_inpainted = {}
        local_all_flagged = {}
        local_cross_pol_data = {}
        local_cross_pol_model = {}

        for pi, pol in enumerate(polarizations):
            # Consolidate all concatenations upfront for this polarization
            nsamples   = np.concatenate([n[..., pi] for n in crosses.nsamples], axis=0)
            flags      = np.concatenate([f[..., pi] for f in crosses.flags], axis=0)
            winp_stack = np.concatenate([w[..., pi] for w in crosses.where_inpainted], axis=0)
            data_stack = np.concatenate([d[..., pi] for d in crosses.data], axis=0)

            model_flags = np.concatenate([
                np.repeat(lst_avg_flags[antpair + (pol,)][[idx]], len(tinb), axis=0)
                for idx, tinb in zip(indices, crosses.times_in_bins)
            ], axis=0)
            flags |= model_flags

            local_wgts[antpair + (pol,)] = nsamples * (~flags).astype(float)

            # Bug fix: was `if pol in flags` (ndarray), should be checking the accumulator dict
            if pol in local_all_flagged:
                local_all_flagged[pol] &= flags
            else:
                local_all_flagged[pol] = flags

            if pol in local_where_inpainted:
                local_where_inpainted[pol] &= winp_stack
            else:
                local_where_inpainted[pol] = winp_stack

            # Use np.repeat instead of broadcasting with np.ones to avoid allocating an intermediate array
            p1, p2 = utils.split_pol(pol)
            if p1 == p2:
                local_data_for_cal[antpair + (pol,)] = data_stack
                local_model[antpair + (pol,)] = np.repeat(
                    single_bl_stacked_data[antpair + (pol,)][indices], times_per_bin, axis=0
                )
            else:
                local_cross_pol_data[antpair + (pol,)] = data_stack
                local_cross_pol_model[antpair + (pol,)] = np.repeat(
                    single_bl_stacked_data[antpair + (pol,)][indices], times_per_bin, axis=0
                )

        return {
            'wgts':            local_wgts,
            'data_for_cal':    local_data_for_cal,
            'model':           local_model,
            'where_inpainted': local_where_inpainted,
            'all_flagged':     local_all_flagged,
            'cross_pol_data':  local_cross_pol_data,
            'cross_pol_model': local_cross_pol_model,
            'freqs':           crosses.hd.freqs,
            'times':           np.concatenate(crosses.times_in_bins),
        }

    # Initialize output dicts
    data_for_cal = {}
    wgts = {}
    model = {}
    where_inpainted = {}
    all_flagged = {}
    freqs = None
    times = None
    if RUN_CROSS_POL_PHASE_CAL:
        cross_pol_data = {}
        cross_pol_model = {}

    # Parallelize across baselines, merging each result as it arrives to keep memory usage low
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_baseline, bl): bl for bl in baseline_strings}

        for future in tqdm(as_completed(futures), total=len(baseline_strings)):
            result = future.result()

            # Merge immediately and drop the reference so memory can be freed
            wgts.update(result['wgts'])
            data_for_cal.update(result['data_for_cal'])
            model.update(result['model'])

            for pol, arr in result['where_inpainted'].items():
                if pol in where_inpainted:
                    where_inpainted[pol] &= arr
                else:
                    where_inpainted[pol] = arr

            for pol, arr in result['all_flagged'].items():
                if pol in all_flagged:
                    all_flagged[pol] &= arr
                else:
                    all_flagged[pol] = arr

            if RUN_CROSS_POL_PHASE_CAL:
                cross_pol_data.update(result['cross_pol_data'])
                cross_pol_model.update(result['cross_pol_model'])

            # Grab freqs/times from first completed result
            if freqs is None:
                freqs = result['freqs']
                times = result['times']

            del result  # Explicitly drop to help GC

    # Explicitly use a known baseline key rather than relying on loop-end state
    single_jd_hd = io.HERAData(single_day_config.bl_to_file_map[baseline_strings[0]])
    single_jd_times = single_jd_hd.times

else:
    print("Not running calibration since JD is not set.")
In [8]:
%%time
if RUN_CALIBRATION:
    autos = SingleBaselineStacker.from_configurator(
        single_day_config,
        auto_baseline_string,
        lst_bin_edges,
        lst_branch_cut=lst_branch_cut, 
        where_inpainted_file_rules=where_inpainted_file_rules
    )
    # Get antennas that make up the baseline
    antpair = autos.hd.antpairs[0]
    polarizations = autos.hd.pols if RUN_CROSS_POL_PHASE_CAL else autos.hd.pols[:2]
    
    auto_model = {}
    for pi, pol in enumerate(polarizations):
        auto_model[antpair + (pol,)] = np.concatenate([_data[..., pi] for _data in autos.data], axis=0)
        p1, p2 = utils.split_pol(pol)
        if INCLUDE_AUTOS and p1 == p2:
            data_for_cal[antpair + (pol,)] = np.concatenate([_data[..., pi] for _data in autos.data], axis=0)
            nsamples = np.concatenate([_nsamples[..., pi] for _nsamples in autos.nsamples], axis=0)
            flags = np.concatenate([_flags[..., pi] for _flags in autos.flags], axis=0)
            wgts[antpair + (pol,)] = nsamples * (~flags).astype(float)

    if INCLUDE_AUTOS:
        autos = SingleBaselineStacker.from_configurator(
            configurator,
            auto_baseline_string,
            lst_bin_edges,
            lst_branch_cut=lst_branch_cut, 
            where_inpainted_file_rules=where_inpainted_file_rules
        )
        lst_avg_autos, _, _ = autos.average_over_nights()
        antpair = autos.hd.antpairs[0]
        
        for pidx, pol in enumerate(polarizations):
            p1, p2 = utils.split_pol(pol)
            if p1 == p2:
                model[antpair + (pol,)] = np.concatenate([
                    lst_avg_autos[[idx], :, pidx] * np.ones((len(tinb), 1))
                    for idx, tinb in zip(indices, crosses.times_in_bins)
                ], axis=0)
CPU times: user 3.7 s, sys: 1.02 s, total: 4.71 s
Wall time: 6.91 s

2. LST-Binned Calibration¶

In [9]:
import jax
from jax import numpy as jnp

def complex_phase_abscal(data, model, reds, data_bls, model_bls, transformed_antpos=None, newton_maxiter=20):
    """
    Stripped down version of hera_cal.abscal.complex_phase_abscal that assumes the tip-tilt solution is close to zero. 
    Calculates gains that would absolute calibrate the phase of already redundantly-calibrated data. 
    Only operates one polarization at a time.

    Parameters:
    ----------
    data : DataContainer or RedDataContainer
        Dictionary-like container mapping baselines to data visibilities to abscal
    model : DataContainer or RedDataContainer
        Dictionary-like container mapping baselines to model visibilities
    reds : list of lists
        List of lists of redundant baselines tuples like (0, 1, 'ee'). Ignored if transformed_antpos is not None.
    data_bls : list of tuples
        List of baseline tuples in data to use.
    model_bls : list of tuples
        List of baseline tuples in model to use. Must correspond the same physical separations as data_bls.
    transformed_antpos : dict
        Dictionary of abstracted antenna positions that you'd normally get from redcal.reds_to_antpos().
        If None, will be inferred from reds.

    Returns:
    -------
    meta : dictionary
        Contains keys for
            'Lambda_sol' : phase gradient solutions,
            'Z_sol' : value of the objective function at the solution,
            'newton_iterations' : number of iterations completed by the Newton's method solver
    delta_gains : dictionary
        Dictionary mapping antenna keys like (0, 'Jee') to gains of the same shape of the data
    """
    # Check that baselines selected are for the same polarization
    pols = list(set([bl[2] for bls in (data_bls, model_bls) for bl in bls]))
    assert len(pols) == 1, 'complex_phase_abscal() can only solve for one polarization at a time.'

    # Get transformed antenna positions and baselines
    if transformed_antpos is None:
        transformed_antpos = redcal.reds_to_antpos(reds)
    abscal._put_transformed_array_on_integer_grid(transformed_antpos)
    transformed_b_vecs = np.rint([transformed_antpos[jj] - transformed_antpos[ii] for (ii, jj, pol) in data_bls]).astype(int)

    # Get number of baselines and times/freqs
    Ngroups = len(data_bls)
    Ntimes, Nfreqs = data[data_bls[0]].shape

    # Build up array of Fourier coefficients of the objective function
    Z_coefficients = np.zeros((Ntimes, Nfreqs, Ngroups), dtype=complex)
    for nn in range(Ngroups):

        Vhat_n = data[data_bls[nn]]
        Vbar_n = model[model_bls[nn]]
        Z_coefficients[:, :, nn] = Vhat_n * np.conj(Vbar_n)

    # Get solution for degenerate phase gradient vectors
    Ntimes, Nfreqs, Ngroups = Z_coefficients.shape
    Ndims = transformed_b_vecs.shape[1]

    Lambda_sol = np.zeros((Ntimes, Nfreqs, Ndims), dtype=float)

    batched_newton = jax.vmap(
        lambda z_coeffs, lambda_init: abscal._newton_solve(
            lambda_init, transformed_b_vecs, z_coeffs, tol=1e-8, maxiter=newton_maxiter
        ),
        in_axes=(0, 0)
    )
    
    Z_coeffs_flat  = jnp.array(Z_coefficients.reshape(Ntimes * Nfreqs, Ngroups))
    Lambda_inits_flat = jnp.array(Lambda_sol.reshape(Ntimes * Nfreqs, Ndims))
    
    Lambda_flat, niters_flat = batched_newton(Z_coeffs_flat, Lambda_inits_flat)
    _ = niters_flat.block_until_ready()
    
    Lambda_sol = np.array(Lambda_flat).reshape(Ntimes, Nfreqs, Ndims)
    newton_iterations = np.array(niters_flat).reshape(Ntimes, Nfreqs)

    # --- vmap _eval_Z for Z_sol ---
    batched_eval_Z = jax.vmap(
        lambda z_coeffs, lam: abscal._eval_Z(lam, transformed_b_vecs, z_coeffs),
        in_axes=(0, 0)
    )
    Z_sol = np.array(
        batched_eval_Z(Z_coeffs_flat, Lambda_flat)
    ).reshape(Ntimes, Nfreqs)
        
    # turn solution into per-antenna gains
    meta = {
        'Lambda_sol': -Lambda_sol, 
        'transformed_antpos': transformed_antpos,
        'Z_sol': Z_sol,
    }
    return meta
In [10]:
%%time
if RUN_CALIBRATION:
    # Amplitude Calibration
    if RUN_AMPLITUDE_CAL:
        amplitude_solutions = abscal.abs_amp_lincal(
            model=model,
            data=data_for_cal,
            wgts=wgts,
            verbose=False,
        )
    else:
        amplitude_solutions = {}
        data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
        amplitude_solutions['ee'] = np.ones(data_shape)
        amplitude_solutions['nn'] = np.ones(data_shape)
    
    # Tip-tilt Phase Calibration
    if RUN_TIP_TILT_PHASE_CAL:
        # Get the redundancies
        all_reds = red_groups.RedundantGroups.from_antpos(
            antpos=hd.antpos, 
            pols=('nn', 'ee'), 
            include_autos=False
        )
    
        # Fit the tip-tilt for both pols
        phase_solutions = {}
        for pol in ['ee', 'nn']:
            phase_fit = complex_phase_abscal(
                {k: data_for_cal[k][:] for k in data_for_cal if k[-1] == pol and k[0] != k[1]}, 
                {k: model[k][:] for k in model if k[-1] == pol and k[0] != k[1]}, 
                all_reds, 
                [k for k in data_for_cal if k[-1] == pol and k[0] != k[1]], 
                [k for k in model if k[-1] == pol and k[0] != k[1]], 
            )
            phase_solutions[pol] = phase_fit['Lambda_sol']
    
        transformed_antpos = phase_fit['transformed_antpos']   
    else:
        data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
        phase_solutions = {}
        for pol in ['ee', 'nn']:
            phase_fit = np.zeros(data_shape + (2,))
            phase_solutions[pol] = phase_fit
    
        transformed_antpos = {k: hd.antpos[k][:2] for k in hd.antpos}
        
    if RUN_CROSS_POL_PHASE_CAL:
        cross_pol_phase = abscal.cross_pol_phase_cal(
            model=cross_pol_model,
            data=cross_pol_data,
            model_bls=list(cross_pol_model.keys()),
            data_bls=list(cross_pol_data.keys()),
            wgts=wgts,
        )
    else:
        data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
        cross_pol_phase = np.zeros(data_shape)
CPU times: user 12min 29s, sys: 40.4 s, total: 13min 9s
Wall time: 6min 35s
In [11]:
if RUN_CALIBRATION:
    # Check blacklisting in amplitude calibration (harder to do in phase)
    blacklist_wgts = {}
    
    avg_wgts = {
        pol: np.mean([wgts[key] for key in wgts if pol in key], axis=0)
        for pol in polarizations
    }
    
    for pol in ['ee', 'nn']:
        if RUN_AMPLITUDE_CAL:
            gains = np.where(
                np.isfinite(amplitude_solutions[f"A_J{pol}"]), 
                amplitude_solutions[f"A_J{pol}"], 
                1.0
            )
            blacklist_wgts_here = np.ones(gains.shape,)
            for i in range(BLACKLIST_NITER):
                wgts_here = np.where(np.isfinite(amplitude_solutions[f"A_J{pol}"]), avg_wgts[pol], 0.0)
                wgts_here = np.where(all_flagged[pol], 0.0, wgts_here)
                wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
                wgts_here *= blacklist_wgts_here
                smoothed_amp, _ = smooth_cal.time_freq_2D_filter(
                    gains=gains.astype(complex),
                    wgts=wgts_here,
                    freqs=freqs,
                    times=times,
                    freq_scale=FREQ_SMOOTHING_SCALE,
                    time_scale=TIME_SMOOTHING_SCALE * BLACKLIST_TIMESCALE_FACTOR,
                    eigenval_cutoff=EIGENVAL_CUTOFF,
                    method='DPSS', 
                    use_sparse_solver=True,
                    precondition_solver=True,
                    fix_phase_flips=False, 
                    flag_phase_flip_ints=False,
                    skip_flagged_edges=True, 
                    freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
                ) 
                blacklist_wgts_here = np.where(
                    (np.abs(gains - smoothed_amp) / np.abs(smoothed_amp)) > BLACKLIST_RELATIVE_ERROR_THRESH,
                    0.0,
                    1.0
                )
            blacklist_wgts[pol] = blacklist_wgts_here.copy()
        else:
            blacklist_wgts[pol] = np.ones_like(amplitude_solutions[f"A_J{pol}"])
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
In [12]:
if RUN_CALIBRATION:
    smoothed_tip_tilt = {}
    smoothed_amplitude = {}
    
    for pol in ['ee', 'nn']:
        if RUN_AMPLITUDE_CAL:
            gains = np.where(
                np.isfinite(amplitude_solutions[f"A_J{pol}"]), 
                amplitude_solutions[f"A_J{pol}"], 
                1.0
            )
            wgts_here = np.where(np.isfinite(amplitude_solutions[f"A_J{pol}"]), avg_wgts[pol], 0.0)
            wgts_here = np.where(all_flagged[pol], 0.0, wgts_here)
            wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
            wgts_here *= blacklist_wgts[pol]
            smoothed_amp, _ = smooth_cal.time_freq_2D_filter(
                gains=gains.astype(complex),
                wgts=wgts_here,
                freqs=freqs,
                times=times,
                freq_scale=FREQ_SMOOTHING_SCALE,
                time_scale=TIME_SMOOTHING_SCALE,
                eigenval_cutoff=EIGENVAL_CUTOFF,
                method='DPSS', 
                use_sparse_solver=True,
                precondition_solver=True,
                fix_phase_flips=False, 
                flag_phase_flip_ints=False,
                skip_flagged_edges=True, 
                freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
            ) 
            smoothed_amplitude[pol] = np.where(
                all_flagged[pol], 1.0, smoothed_amp.real
            )
        else:
            smoothed_amplitude[pol] = np.ones_like(amplitude_solutions[pol])
            
        if RUN_TIP_TILT_PHASE_CAL:
            smoothed_solutions = []
            for i in range(2):
                tip_tilt = phase_solutions[pol][..., i].astype(complex)
                gains = np.where(np.isfinite(tip_tilt), tip_tilt, 0.0)
                wgts_here = np.where(np.isfinite(tip_tilt), avg_wgts[pol], 0.0)
                wgts_here = np.where(all_flagged[pol], 0.0, wgts_here)
                wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
                wgts_here *= blacklist_wgts[pol]
                
                tip_tilt_smoothed, _ = smooth_cal.time_freq_2D_filter(
                    gains=gains,
                    wgts=wgts_here,
                    freqs=freqs,
                    times=times,
                    freq_scale=FREQ_SMOOTHING_SCALE,
                    time_scale=TIME_SMOOTHING_SCALE,
                    eigenval_cutoff=EIGENVAL_CUTOFF,
                    method='DPSS', 
                    use_sparse_solver=True,
                    precondition_solver=True,
                    fix_phase_flips=False, 
                    flag_phase_flip_ints=False,
                    skip_flagged_edges=True, 
                    freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
                ) 
                smoothed_solutions.append(tip_tilt_smoothed.real)
            
            smoothed_tip_tilt[pol] = np.where(
                all_flagged[pol][..., None], 0.0, np.transpose(smoothed_solutions, (1, 2, 0))
            )
        else:
            smoothed_tip_tilt[pol] = np.zeros(wgts[list(wgts.keys())[0]].shape + (2,))
    
    if RUN_CROSS_POL_PHASE_CAL:
        gains = np.where(np.isfinite(cross_pol_phase), cross_pol_phase, 0.0)
        wgts_here = np.where(np.isfinite(cross_pol_phase), avg_wgts['en'] + avg_wgts['ne'], 0.0)
        wgts_here = np.where(all_flagged['ee'] | all_flagged['nn'], 0.0, wgts_here)
        wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
        wgts_here *= blacklist_wgts[pol]
        
        cross_pol_smoothed, _ = smooth_cal.time_freq_2D_filter(
            gains=gains.astype(complex),
            wgts=wgts_here,
            freqs=freqs,
            times=times,
            freq_scale=FREQ_SMOOTHING_SCALE,
            time_scale=TIME_SMOOTHING_SCALE,
            eigenval_cutoff=EIGENVAL_CUTOFF,
            method='DPSS', 
            use_sparse_solver=True,
            precondition_solver=True,
            fix_phase_flips=False, 
            flag_phase_flip_ints=False,
            skip_flagged_edges=True, 
            freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
        ) 
        cross_pol_smoothed = np.where(
            all_flagged["nn"] | all_flagged["ee"],
            0.0,
            cross_pol_smoothed
        )
    else:
        cross_pol_smoothed = np.zeros(wgts[list(wgts.keys())[0]].shape)
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
In [13]:
def plot_amplitude_degeneracy():
    ai, aj = list(map(int, baseline_strings[0].split("_")))
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]
    
    fig, axs = plt.subplots(2, 2, figsize=(15, 10), sharex=True, sharey=True)
    for pi, pol in enumerate(['ee', 'nn']):
        axs[pi, 0].imshow(
            np.where(
                where_inpainted[pol] | (all_flagged[pol]), 
                np.nan, 
                np.abs(amplitude_solutions[f'A_J{pol}'])
            ), 
            aspect='auto', 
            interpolation='None', 
            vmin=0.95, 
            vmax=1.05, 
            extent=extent, 
            cmap='turbo'
        )
        im = axs[pi, 1].imshow(
            np.where(all_flagged[pol], np.nan, np.abs(smoothed_amplitude[pol])),
            aspect='auto', 
            interpolation='None', 
            vmin=0.95, 
            vmax=1.05, 
            extent=extent, 
            cmap='turbo'
        )
        
    plt.tight_layout()
    axs[1, 0].set_xlabel("Frequency (MHz)", fontsize=12)
    axs[1, 1].set_xlabel("Frequency (MHz)", fontsize=12)
    axs[0, 0].set_ylabel("ee", fontsize=12)
    axs[1, 0].set_ylabel("nn", fontsize=12)
    axs[0, 0].set_title("Raw Amplitude Degeneracy", fontsize=14)
    axs[0, 1].set_title("Smoothed Amplitude Degeneracy", fontsize=14)
    cbar = plt.colorbar(im, ax=axs, fraction=0.05, pad=0.01)
    cbar.set_label("Gain Amplitude", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)

def plot_tip_tilt_degeneracy():
    ai, aj = list(map(int, baseline_strings[0].split("_")))
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]

    antvec = np.array([hd.antpos[k][:2] for k in hd.antpos])
    antvec -= antvec[0]
    transformed_antvec = np.array([transformed_antpos[k][:2] for k in hd.antpos])
    transformed_antvec -= transformed_antvec[0]
    coord_scalar = np.diag(np.linalg.solve(transformed_antvec.T.dot(transformed_antvec), transformed_antvec.T.dot(antvec)))
    
    fig, axs = plt.subplots(4, 2, figsize=(15, 12), sharex=True, sharey=True)
    ci = 0
    for pi, pol in enumerate(['ee', 'nn']):
        for ni in range(2):
            axs[ci, 0].imshow(
                np.where(
                    where_inpainted[pol] | (all_flagged[pol]),
                    np.nan,
                    phase_solutions[pol][..., ni] * coord_scalar[ni],
                ),
                aspect='auto', 
                interpolation='None', 
                vmin=-0.05, 
                vmax=0.05, 
                extent=extent, 
                cmap='turbo'
            )
            tip_tilt = smoothed_tip_tilt[pol][..., ni] * coord_scalar[ni]
            im = axs[ci, 1].imshow( 
                np.where(all_flagged[pol], np.nan, tip_tilt),
                aspect='auto', 
                interpolation='None', 
                vmin=-0.05, 
                vmax=0.05, 
                extent=extent, 
                cmap='turbo'
            )

            # Labeling
            axs[ci, 0].set_ylabel(f"{pol} component-{ci%2}")
            
            ci += 1
        
        axs[3, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
        axs[3, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)

    axs[0, 0].set_title("Raw Tip-Tilt Degeneracy", fontsize=14)
    axs[0, 1].set_title("Smoothed Tip-Tilt Degeneracy", fontsize=14)
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs, fraction=0.04, pad=0.01)
    cbar.set_label("Phase Gradient (rad/m)", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)

def plot_cross_polarized_phase_degeneracy():
    ai, aj = list(map(int, baseline_strings[0].split("_")))
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]
    
    fig, axs = plt.subplots(1, 2, figsize=(15, 6), sharex=True, sharey=True)
    pol = 'ne'
    axs[0].imshow(
        np.where(
            where_inpainted[pol] | (all_flagged[pol]),
            np.nan,
            cross_pol_phase,
        ),
        aspect='auto', 
        interpolation='None', 
        vmin=-0.1, 
        vmax=0.1, 
        extent=extent, 
        cmap='coolwarm'
    )
    im = axs[1].imshow( 
        np.where(all_flagged[pol], np.nan, cross_pol_smoothed.real),
        aspect='auto', 
        interpolation='None', 
        vmin=-0.1, 
        vmax=0.1, 
        extent=extent, 
        cmap='coolwarm'
    )

    # Labeling
    axs[0].set_ylabel(r"LST (hr)")    
    axs[0].set_xlabel(r"Frequency (MHz)")
    axs[1].set_xlabel(r"Frequency (MHz)")
    axs[0].set_title("Raw Relative Phase Degeneracy")
    axs[1].set_title("Smoothed Relative Phase Degeneracy")
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs, fraction=0.05, pad=0.01)
    cbar.set_label("Jee/Jnn Relative Phase (rad)", fontsize=14)

def expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=True):
    """
    """
    ant1pol, ant2pol = utils.split_bl(key)
    blvec = transformed_antpos[ant2pol[0]] - transformed_antpos[ant1pol[0]]
    gain = all_calibration_parameters[f"A_{ant1pol[1]}"].astype(complex) * all_calibration_parameters[f"A_{ant1pol[1]}"].astype(complex)
    g1 = np.exp(
        1j * np.einsum("tfc,c->tf", all_calibration_parameters[f"T_{ant2pol[1]}"], transformed_antpos[ant2pol[0]])
    )
    g2 = np.exp(
        1j * np.einsum("tfc,c->tf", all_calibration_parameters[f"T_{ant1pol[1]}"], transformed_antpos[ant1pol[0]])
    )
    gain *= g1 * g2.conj()

    if use_cross_pol:
        if ant1pol[-1] != ant2pol[-1]: 
            if ant1pol[-1] == 'Jnn':
                g1 = np.exp(1j * all_calibration_parameters['cross_pol'])
                gain *= g1
            elif ant2pol[-1] == 'Jnn':
                g1 = np.exp(-1j * all_calibration_parameters['cross_pol'])
                gain *= g1
    
    return gain

def plot_excess_variance():
    ai, aj = list(map(int, baseline_strings[0].split("_")))
    lsts = utils.JD2LST(times) * 12 / np.pi

    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
        
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]
    
    all_calibration_parameters = {
        "A_Jee": smoothed_amplitude['ee'],
        "A_Jnn": smoothed_amplitude['nn'],
        "T_Jee": smoothed_tip_tilt['ee'],
        "T_Jnn": smoothed_tip_tilt['nn'],
        "cross_pol": cross_pol_smoothed,
    }
    
    if RUN_CROSS_POL_PHASE_CAL:
        fig, axs = plt.subplots(4, 2, figsize=(15, 10), sharey=True, sharex=True)
    else:
        fig, axs = plt.subplots(2, 2, figsize=(15, 6), sharey=True, sharex=True)
    
    for pi, pol in enumerate(['ee', 'nn']):
        excess_var = 0
        excess_var_cal = 0
        count = 0
        weights = 0
        for key in data_for_cal:
            if pol in key and key[0] != key[1]:
                noise_var = np.abs(auto_model[(0, 0, pol)]) ** 2 / wgts[key] / 10 / 122e3
                zsquare = np.abs(data_for_cal[key] - model[key]) ** 2 / noise_var
                excess_var += zsquare * (wgts[key])# * (~where_inpainted[pol]).astype(float))
                gain = expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=True)
                data_cal = data_for_cal[key] / gain
                zsquare = np.abs(data_cal - model[key]) ** 2 / noise_var
                excess_var_cal += zsquare * (wgts[key])# * (~where_inpainted[pol]).astype(float))
                weights += (wgts[key])# * (~where_inpainted[pol]).astype(float))
    
        im = axs[pi, 0].imshow(
            np.real(np.abs(excess_var) / weights), 
            aspect='auto', 
            interpolation='None', 
            cmap='turbo', 
            vmin=0.5, 
            vmax=10, 
            extent=extent
        )
        im = axs[pi, 1].imshow(
            np.real(np.abs(excess_var_cal) / weights), 
            aspect='auto', 
            interpolation='None', 
            cmap='turbo', 
            vmin=0.5, 
            vmax=10, 
            extent=extent
        )
    if RUN_CROSS_POL_PHASE_CAL:
        for pi, pol in enumerate(['en', 'ne']):
            excess_var = 0
            excess_var_cal = 0
            count = 0
            weights = 0
            for key in cross_pol_data:
                if pol in key:
                    noise_var = np.abs(auto_model[(0, 0, "ee")] * auto_model[(0, 0, "nn")]) / wgts[key] / 10 / 122e3
                    zsquare = np.abs(cross_pol_data[key] - cross_pol_model[key]) ** 2 / noise_var
                    excess_var += zsquare * (wgts[key])# * (~where_inpainted[pol]).astype(float))
                    gain = expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=RUN_CROSS_POL_PHASE_CAL)
                    data_cal = cross_pol_data[key] / gain
                    zsquare = np.abs(data_cal - cross_pol_model[key]) ** 2 / noise_var
                    excess_var_cal += zsquare * (wgts[key])# * (~where_inpainted[pol]).astype(float))
                    weights += (wgts[key])# * (~where_inpainted[pol]).astype(float))
        
            im = axs[pi + 2, 0].imshow(
                np.real(np.abs(excess_var) / weights), 
                aspect='auto', 
                interpolation='None', 
                cmap='turbo', 
                vmin=0.5, 
                vmax=10, 
                extent=extent
            )
            im = axs[pi + 2, 1].imshow(
                np.real(np.abs(excess_var_cal) / weights), 
                aspect='auto', 
                interpolation='None', 
                cmap='turbo', 
                vmin=0.5, 
                vmax=10, 
                extent=extent
            )
    
    # Labeling
    for i, pol in enumerate(polarizations):
        axs[i, 0].set_ylabel(pol)   

    if RUN_CROSS_POL_PHASE_CAL:
        axs[3, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
        axs[3, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)
    else:
        axs[1, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
        axs[1, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)
        
    axs[0, 0].set_title("Pre-LST Calibration", fontsize=14)
    axs[0, 1].set_title("Post-LST Calibration", fontsize=14)
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs)
    cbar.set_label(r"Excess Variance $[|V^{\rm night} - V^{\rm LST}|^2 / V^2_N]$", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)

Figure 1: Amplitude Parameters Before/After Smoothing¶

In [14]:
# Get the local times for plotting
if RUN_AMPLITUDE_CAL and RUN_CALIBRATION:
    plot_amplitude_degeneracy()
    #plt.ylim([14, 16])
    #plt.xlim([100, 130])
No description has been provided for this image

Figure 2: Tip/Tilt Parameters Before/After Smoothing¶

In [15]:
if RUN_TIP_TILT_PHASE_CAL and RUN_CALIBRATION:
    plot_tip_tilt_degeneracy()
No description has been provided for this image

Figure 3: Cross-Polarized Phase Before/After Smoothing¶

In [16]:
if RUN_CROSS_POL_PHASE_CAL and RUN_CALIBRATION:
    plot_cross_polarized_phase_degeneracy()
No description has been provided for this image

Figure 4: Visibility/LST-Averaged Variance Across Baseline Before/After LST-Cal¶

In [17]:
if RUN_CALIBRATION:
    plot_excess_variance()
divide by zero encountered in divide
invalid value encountered in divide
invalid value encountered in divide
divide by zero encountered in divide
invalid value encountered in divide
invalid value encountered in divide
No description has been provided for this image

4. Save Smoothed Results¶

In [18]:
if RUN_CALIBRATION:
    indices = np.searchsorted(single_jd_times, times)
    
    # Expand out tip/tilt to full data size
    expanded_tip_tilt = {
        pol: np.zeros((single_jd_times.size,) + smoothed_tip_tilt[pol].shape[1:])
        for pol in smoothed_tip_tilt
    }
    expanded_amplitude = {
        pol: np.ones((single_jd_times.size,) + smoothed_amplitude[pol].shape[1:])
        for pol in smoothed_amplitude
    }
    for pol in expanded_tip_tilt:
        expanded_tip_tilt[pol][indices] = smoothed_tip_tilt[pol]
        expanded_amplitude[pol][indices] = np.where(
            np.isclose(smoothed_amplitude[pol], 0.0),
            1.0,
            smoothed_amplitude[pol]
        )
    
    # Expand out cross-polarized degeneracy to full data size
    expanded_cross_pol = np.zeros((single_jd_times.size,) + cross_pol_smoothed.shape[1:])
    expanded_cross_pol[indices] = cross_pol_smoothed
    
    # Expand out flags to full data size
    expanded_flags = {
        pol: np.ones((single_jd_times.size,) + all_flagged[pol].shape[1:], dtype=bool)
        for pol in all_flagged
    }
    for pol in expanded_flags:
        expanded_flags[pol][indices] = all_flagged[pol]
Casting complex values to real discards the imaginary part
In [19]:
if RUN_CALIBRATION:
    # Store calibration parameters
    all_calibration_parameters = {
        "A_Jee": expanded_amplitude['ee'],
        "A_Jnn": expanded_amplitude['nn'],
        "T_Jee": expanded_tip_tilt['ee'],
        "T_Jnn": expanded_tip_tilt['nn'],
        "cross_pol": expanded_cross_pol,
    }
    
    # Get the calibration filename
    cal_fname = os.path.join(OUTDIR, LSTCAL_FNAME_FORMAT.format(night=jd_here))

    # Write LST-cal solutions to disk
    lst_stack.calibration.write_single_baseline_lstcal_solutions(
        filename=cal_fname, 
        all_calibration_parameters=all_calibration_parameters, 
        flags=expanded_flags,
        transformed_antpos=transformed_antpos, 
        antpos=hd.antpos,
        times=single_jd_times, 
        freqs=freqs, 
        pols=polarizations
    )
In [20]:
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    exec(f'from {repo} import __version__')
    print(f'{repo}: {__version__}')
hera_cal: 3.7.8.dev38+g2607f380d
hera_qm: 2.2.1.dev12+g95ecc30f0
hera_filters: 0.1.9.dev3+ged4deb46a
hera_notebook_templates: 0.0.1.dev1479+g101a39c6e
pyuvdata: 3.2.7.dev8+g5fca0c330
numpy: 2.3.5
In [21]:
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')
Finished execution in 13.63 minutes.
In [ ]: