Single Night LST-Binned Calibration Notebook¶
by Tyler Cox and Josh Dillon, last updated on April 14th, 2026
This notebook performs LST-binned calibration (or LST-cal) on whole-JD, single-baseline, all polarization files. Most parameters are controlled by a toml config file, such as this one. In addition to single-baseline files, this notebook also requires UVFlag-compatible where_inpainted files which tell us where inpainting was previously done.
To keep the total memory footprint of the notebook reasonable, the full list of baselines is first downselected to the subset of Nbls redundant baseline types that have the largest number of nsamples. The notebook then computes the redundant-calibration degenerate parameters (namely, the per-frequency/time amplitude, tip-tilt, and cross-polarized phase degeneracies) that bring a single night into better alignment with the LST-average. The calibration parameters are then smoothed in time and frequency with DPSS basis functions, given user-specified time and frequency smoothing scales.
Here's a set of links to skip to particular figures and tables:
• Figure 1: Amplitude Parameters Before/After Smoothing¶
• Figure 2: Tip/Tilt Parameters Before/After Smoothing¶
• Figure 3: Cross-Polarized Phase Before/After Smoothing¶
• Figure 4: Visibility/LST-Averaged Variance Across Baseline Before/After LST-Cal¶
# Record the wall-clock start time; the last cell of the notebook uses it to report total runtime
import time
tstart = time.time()
# IPython shell escapes: log the host and the date of this run
!hostname
!date
herapost007
Tue May 12 04:46:37 MDT 2026
import jax
jax.config.update('jax_platform_name', 'cpu') # Force jax to use CPU if GPU available
import re
import os
import yaml
import glob
import toml
import numpy as np
import pylab as plt
from copy import deepcopy
from functools import reduce
from tqdm.notebook import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from pyuvdata import UVData
from hera_cal.lst_stack import LSTConfig
from hera_cal import lst_stack, io, flag_utils, abscal, datacontainer, redcal, utils, smooth_cal, red_groups
from hera_qm.time_series_metrics import true_stretches
from hera_filters.dspec import fourier_filter, dpss_operator
from hera_cal.lst_stack.calibration import _expand_degeneracies_to_ant_gains
from hera_cal.lst_stack.config import LSTBinConfiguratorSingleBaseline, make_lst_grid
from hera_cal.lst_stack.binning import SingleBaselineStacker, _get_freqs_chans, adjust_lst_bin_edges, _allocate_dfn, get_lst_bins
from hera_qm.metrics_io import read_a_priori_ant_flags
import warnings
warnings.filterwarnings("ignore", module="hera_cal")
%load_ext line_profiler
import jax
jax.config.update('jax_platform_name', 'cpu') # Force jax to use CPU if GPU available
# Location of the toml config file; the TOML_FILE environment variable overrides the default path
toml_file = os.getenv(
    'TOML_FILE',
    '/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/single_bl_lst_stack.toml',
)
print(f'toml_file = "{toml_file}"')

# Night to calibrate, read from the JD environment variable (None when JD is unset)
jd_here = int(os.environ['JD']) if 'JD' in os.environ else None
print(f'jd_here = {jd_here}')
toml_file = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml" jd_here = 2459979
# Load the toml options, echo each one, and promote it to a notebook-level global
toml_options = toml.load(toml_file)
print(f"Now setting the following global variables from {toml_file}:\n")

# Two file-configuration entries are needed as globals
for key in ('lst_branch_cut', 'where_inpainted_file_rules'):
    globals()[key] = toml_options['FILE_CFG'][key]
    print(f"{key} = {globals()[key]}")

# this is used for an initial stacking of a handful of baselines, which are then used for LSTCal
toml_options['LST_STACK_OPTS']['FNAME_FORMAT'] = toml_options['LST_STACK_OPTS']['FNAME_FORMAT'].replace(
    '.sum.uvh5', '.preliminary.sum.uvh5'
)

# Echo then export each section; LST_STACK_OPTS comes last, so it wins on any duplicated key
for section in ('LSTCAL_OPTS', 'LST_STACK_OPTS'):
    for key, val in toml_options[section].items():
        shown = f'"{val}"' if isinstance(val, str) else val
        print(f'{key} = {shown}')
    globals().update(toml_options[section])
Now setting the following global variables from /lustre/aoc/projects/hera/h6c-analysis/IDR3/src/hera_pipelines/pipelines/h6c/idr3/v1/lstbin/lststack_131_nights/single_bl_lst_stack.toml:
lst_branch_cut = 5.2
where_inpainted_file_rules = [['.reinpainted.uvh5', '.where_reinpainted.h5']]
NBLS_FOR_LSTCAL = 30
RUN_AMPLITUDE_CAL = True
RUN_TIP_TILT_PHASE_CAL = True
RUN_CROSS_POL_PHASE_CAL = True
INCLUDE_AUTOS = False
FREQ_SMOOTHING_SCALE = 30.0
TIME_SMOOTHING_SCALE = 2500
BLACKLIST_TIMESCALE_FACTOR = 10
BLACKLIST_RELATIVE_ERROR_THRESH = 0.1
BLACKLIST_NITER = 5
WHERE_INPAINTED_WGTS = 0.0001
LSTCAL_FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.{night}.lstcal.hdf5"
MAX_BL_LENGTH = 100
MIN_AVG_REDUNDANCY = 0.0
OUTDIR = "/lustre/aoc/projects/hera/h6c-analysis/IDR3/lststack-outputs-131-nights"
FNAME_FORMAT = "single_bl_reinpaint_1000ns/zen.LST.baseline.{bl_str}.preliminary.sum.uvh5"
USE_LSTCAL_GAINS = True
FM_LOW_FREQ = 87.5
FM_HIGH_FREQ = 108.0
EIGENVAL_CUTOFF = 1e-12
INPAINT_DELAY = 1000
AUTO_INPAINT_DELAY = 100
REINPAINT_AUTOS = False
INPAINT_WIDTH_FACTOR = 0.5
INPAINT_ZERO_DIST_WEIGHT = 0.01
MIN_NIGHTS_PER_BIN = 3
MAX_CHANNEL_NIGHTLY_FLAG_FRAC = 0.25
MOD_Z_TO_REINPAINT = 5
AUTO_FR_SPECTRUM_FILE = "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5"
GAUSS_FIT_BUFFER_CUT = 1e-05
CG_TOL = 1e-06
CG_ITER_LIM = 500
FR0_FILTER = True
FR0_HALFWIDTH = 0.01
FR0_FILT_AUTOS = False
# Configurator mapping each baseline string to its per-night files, built from the toml file
configurator = LSTBinConfiguratorSingleBaseline.from_toml(toml_file)

# An autocorrelation is a baseline string "i_j" with i == j; use the first one found
auto_baseline_string = [
    bl_str for bl_str in configurator.bl_to_file_map
    if bl_str.split('_')[0] == bl_str.split('_')[1]
][0]

# Metadata for LST-stacking, read from the last file listed for the autocorrelation
hd = io.HERAData(configurator.bl_to_file_map[auto_baseline_string][-1])
df = np.median(np.diff(hd.freqs))
dlst = np.median(np.diff(hd.lsts))

# Full-circle LST grid with bin edges half a bin to either side of each grid point
# (_fix_dlst function making the grid the wrong size)
lst_grid = make_lst_grid(dlst, begin_lst=0, lst_width=(2 * np.pi))
lst_bin_edges = np.append(lst_grid - dlst / 2, lst_grid[-1] + dlst / 2)

# Julian dates for LST-calibration and the matching a posteriori flag yaml of each night
jds = list(map(int, configurator.nights))
filepath = toml_options['FILE_CFG']['datafiles']['datadir']
aposteriori_yamls = {}
for jd in jds:
    aposteriori_yamls[jd] = filepath + f'/{jd}/{jd}_aposteriori_flags.yaml'
1. Load Single Night Data and Rephase to Correct LST-bins¶
# Get baseline keys from the files that have been previously saved by the stacking notebook
single_bl_files = glob.glob(
    os.path.join(OUTDIR, FNAME_FORMAT.format(bl_str="*"))
)

# Turn the filename template into a regex that captures the two antenna numbers of "{bl_str}"
pattern = '^' + re.escape(os.path.join(OUTDIR, FNAME_FORMAT)).replace(r'\{bl_str\}', r'(?P<a>\d+)_(?P<b>\d+)') + '$'
rx = re.compile(pattern)

baselines = []
baseline_strings = []
for file in single_bl_files:
    match = rx.match(file)
    if match is None:
        # The glob wildcard also matches names that are not "<ant1>_<ant2>";
        # skip those instead of failing on `None.groups()`.
        warnings.warn(f"Skipping {file}: filename does not contain an '<ant1>_<ant2>' baseline string.")
        continue
    i, j = int(match['a']), int(match['b'])
    baselines.append((i, j))
    baseline_strings.append(f"{i}_{j}")
# Calibration only runs when a night was selected through the JD environment variable
RUN_CALIBRATION = jd_here is not None

if RUN_CALIBRATION:
    # Pre-compute rounding factor once — only depends on dlst, not per-baseline
    lst_grid_rounding_factor = np.abs(np.floor(np.log10(dlst) - 2)).astype(int)

    # Copy the configurator and keep only the files whose path contains this JD
    single_day_config = deepcopy(configurator)
    single_day_config.nights = [str(jd_here)]
    single_day_config.bl_to_file_map = {
        bl_str: [file for file in configurator.bl_to_file_map[bl_str] if str(jd_here) in file]
        for bl_str in configurator.bl_to_file_map
    }

    def process_baseline(bl_string):
        """
        Load one baseline of the selected night and pair it with its LST-stacked counterpart.

        Parameters:
        ----------
        bl_string : str
            Baseline string of the form "<ant1>_<ant2>", used both as a key of the
            configurator's file map and to fill FNAME_FORMAT.

        Returns:
        -------
        result : dict
            'wgts', 'data_for_cal', 'model', 'cross_pol_data', 'cross_pol_model' are dicts keyed by
            (ant1, ant2, pol); 'where_inpainted' and 'all_flagged' are dicts keyed by pol;
            'freqs' are the frequencies of the night's data and 'times' the concatenated
            times of all LST bins.
        """
        crosses = SingleBaselineStacker.from_configurator(
            single_day_config,
            bl_string,
            lst_bin_edges,
            lst_branch_cut=lst_branch_cut,
            where_inpainted_file_rules=where_inpainted_file_rules
        )
        antpair = crosses.hd.antpairs[0]
        bl_to_load = os.path.join(OUTDIR, FNAME_FORMAT.format(bl_str=bl_string))
        polarizations = crosses.hd.pols if RUN_CROSS_POL_PHASE_CAL else crosses.hd.pols[:2]

        # Preliminary LST-stacked data for this baseline: this is the calibration model
        hd = io.HERAData(bl_to_load)
        single_bl_stacked_data, lst_avg_flags, _ = hd.read(polarizations=polarizations)

        # Unwrap the stacked LSTs so they increase from the first entry, then find the
        # stacked row belonging to each of this night's LST bins (compared after rounding)
        lst_grid = hd.lsts.copy()
        lst_grid[lst_grid[0] > lst_grid] += 2 * np.pi
        indices = np.searchsorted(
            np.round(lst_grid, lst_grid_rounding_factor),
            np.round(crosses.bin_lst, lst_grid_rounding_factor)
        )

        # Pre-compute once outside the pol loop — same for all polarizations
        times_per_bin = [len(t) for t in crosses.times_in_bins]

        local_wgts = {}
        local_data_for_cal = {}
        local_model = {}
        local_where_inpainted = {}
        local_all_flagged = {}
        local_cross_pol_data = {}
        local_cross_pol_model = {}

        for pi, pol in enumerate(polarizations):
            bl_key = antpair + (pol,)

            # Consolidate all concatenations upfront for this polarization
            nsamples = np.concatenate([n[..., pi] for n in crosses.nsamples], axis=0)
            flags = np.concatenate([f[..., pi] for f in crosses.flags], axis=0)
            winp_stack = np.concatenate([w[..., pi] for w in crosses.where_inpainted], axis=0)
            data_stack = np.concatenate([d[..., pi] for d in crosses.data], axis=0)

            # Also flag every sample whose LST-averaged model is flagged; each stacked row is
            # repeated once per time that fell into the corresponding bin
            flags |= np.repeat(lst_avg_flags[bl_key][indices], times_per_bin, axis=0)
            local_wgts[bl_key] = nsamples * (~flags).astype(float)

            # Each polarization is visited exactly once per baseline, so these are plain
            # assignments; AND-ing across baselines happens when the results are merged
            local_all_flagged[pol] = flags
            local_where_inpainted[pol] = winp_stack

            # Expand the stacked model to one row per time of the night
            model_stack = np.repeat(single_bl_stacked_data[bl_key][indices], times_per_bin, axis=0)
            p1, p2 = utils.split_pol(pol)
            if p1 == p2:
                local_data_for_cal[bl_key] = data_stack
                local_model[bl_key] = model_stack
            else:
                local_cross_pol_data[bl_key] = data_stack
                local_cross_pol_model[bl_key] = model_stack

        return {
            'wgts': local_wgts,
            'data_for_cal': local_data_for_cal,
            'model': local_model,
            'where_inpainted': local_where_inpainted,
            'all_flagged': local_all_flagged,
            'cross_pol_data': local_cross_pol_data,
            'cross_pol_model': local_cross_pol_model,
            'freqs': crosses.hd.freqs,
            'times': np.concatenate(crosses.times_in_bins),
        }

    # Initialize output dicts
    data_for_cal = {}
    wgts = {}
    model = {}
    where_inpainted = {}
    all_flagged = {}
    freqs = None
    times = None
    if RUN_CROSS_POL_PHASE_CAL:
        cross_pol_data = {}
        cross_pol_model = {}

    # Parallelize across baselines, merging each result as it arrives to keep memory usage low
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_baseline, bl): bl for bl in baseline_strings}
        for future in tqdm(as_completed(futures), total=len(baseline_strings)):
            result = future.result()

            # Merge immediately and drop the reference so memory can be freed
            wgts.update(result['wgts'])
            data_for_cal.update(result['data_for_cal'])
            model.update(result['model'])

            # A sample stays inpainted/flagged only if it is so on every baseline
            for merged, name in ((where_inpainted, 'where_inpainted'), (all_flagged, 'all_flagged')):
                for pol, arr in result[name].items():
                    if pol in merged:
                        merged[pol] &= arr
                    else:
                        merged[pol] = arr

            if RUN_CROSS_POL_PHASE_CAL:
                cross_pol_data.update(result['cross_pol_data'])
                cross_pol_model.update(result['cross_pol_model'])

            # Grab freqs/times from first completed result
            if freqs is None:
                freqs = result['freqs']
                times = result['times']

            del result  # Explicitly drop to help GC

    # Explicitly use a known baseline key rather than relying on loop-end state
    single_jd_hd = io.HERAData(single_day_config.bl_to_file_map[baseline_strings[0]])
    single_jd_times = single_jd_hd.times
else:
    print("Not running calibration since JD is not set.")
%%time
if RUN_CALIBRATION:
autos = SingleBaselineStacker.from_configurator(
single_day_config,
auto_baseline_string,
lst_bin_edges,
lst_branch_cut=lst_branch_cut,
where_inpainted_file_rules=where_inpainted_file_rules
)
# Get antennas that make up the baseline
antpair = autos.hd.antpairs[0]
polarizations = autos.hd.pols if RUN_CROSS_POL_PHASE_CAL else autos.hd.pols[:2]
auto_model = {}
for pi, pol in enumerate(polarizations):
auto_model[antpair + (pol,)] = np.concatenate([_data[..., pi] for _data in autos.data], axis=0)
p1, p2 = utils.split_pol(pol)
if INCLUDE_AUTOS and p1 == p2:
data_for_cal[antpair + (pol,)] = np.concatenate([_data[..., pi] for _data in autos.data], axis=0)
nsamples = np.concatenate([_nsamples[..., pi] for _nsamples in autos.nsamples], axis=0)
flags = np.concatenate([_flags[..., pi] for _flags in autos.flags], axis=0)
wgts[antpair + (pol,)] = nsamples * (~flags).astype(float)
if INCLUDE_AUTOS:
autos = SingleBaselineStacker.from_configurator(
configurator,
auto_baseline_string,
lst_bin_edges,
lst_branch_cut=lst_branch_cut,
where_inpainted_file_rules=where_inpainted_file_rules
)
lst_avg_autos, _, _ = autos.average_over_nights()
antpair = autos.hd.antpairs[0]
for pidx, pol in enumerate(polarizations):
p1, p2 = utils.split_pol(pol)
if p1 == p2:
model[antpair + (pol,)] = np.concatenate([
lst_avg_autos[[idx], :, pidx] * np.ones((len(tinb), 1))
for idx, tinb in zip(indices, crosses.times_in_bins)
], axis=0)
CPU times: user 3.88 s, sys: 1.49 s, total: 5.36 s Wall time: 5.77 s
2. LST-Binned Calibration¶
import jax
from jax import numpy as jnp
def complex_phase_abscal(data, model, reds, data_bls, model_bls, transformed_antpos=None, newton_maxiter=20):
    """
    Stripped down version of hera_cal.abscal.complex_phase_abscal that assumes the tip-tilt solution is close to zero.
    Solves for the degenerate phase gradient that would absolute calibrate the phase of already
    redundantly-calibrated data. Only operates one polarization at a time. Unlike the hera_cal version,
    no per-antenna gains are returned; only the solver results are.

    Parameters:
    ----------
    data : DataContainer or RedDataContainer
        Dictionary-like container mapping baselines to data visibilities to abscal
    model : DataContainer or RedDataContainer
        Dictionary-like container mapping baselines to model visibilities
    reds : list of lists
        List of lists of redundant baselines tuples like (0, 1, 'ee'). Ignored if transformed_antpos is not None.
    data_bls : list of tuples
        List of baseline tuples in data to use.
    model_bls : list of tuples
        List of baseline tuples in model to use. Must correspond the same physical separations as data_bls.
    transformed_antpos : dict
        Dictionary of abstracted antenna positions that you'd normally get from redcal.reds_to_antpos().
        If None, will be inferred from reds.
    newton_maxiter : int
        Maximum number of iterations handed to the Newton's method solver. Default 20.

    Returns:
    -------
    meta : dictionary
        Contains keys for
            'Lambda_sol' : negative of the phase gradient found by the solver, shape (Ntimes, Nfreqs, Ndims),
            'transformed_antpos' : the abstracted antenna positions used for the fit,
            'Z_sol' : value of the objective function at the solution, shape (Ntimes, Nfreqs),
            'newton_iterations' : number of iterations completed by the Newton's method solver,
                shape (Ntimes, Nfreqs)
    """
    # Check that baselines selected are for the same polarization
    pols = list(set([bl[2] for bls in (data_bls, model_bls) for bl in bls]))
    assert len(pols) == 1, 'complex_phase_abscal() can only solve for one polarization at a time.'

    # Get transformed antenna positions and baselines
    if transformed_antpos is None:
        transformed_antpos = redcal.reds_to_antpos(reds)
        abscal._put_transformed_array_on_integer_grid(transformed_antpos)
    transformed_b_vecs = np.rint([transformed_antpos[jj] - transformed_antpos[ii] for (ii, jj, pol) in data_bls]).astype(int)

    # Get number of baselines and times/freqs
    Ngroups = len(data_bls)
    Ntimes, Nfreqs = data[data_bls[0]].shape
    Ndims = transformed_b_vecs.shape[1]

    # Build up array of Fourier coefficients of the objective function
    Z_coefficients = np.zeros((Ntimes, Nfreqs, Ngroups), dtype=complex)
    for nn, data_bl in enumerate(data_bls):
        Z_coefficients[:, :, nn] = data[data_bl] * np.conj(model[model_bls[nn]])

    # Solve every (time, freq) sample independently, starting from a zero phase gradient,
    # by mapping the Newton solver over the flattened time/frequency axis
    batched_newton = jax.vmap(
        lambda z_coeffs, lambda_init: abscal._newton_solve(
            lambda_init, transformed_b_vecs, z_coeffs, tol=1e-8, maxiter=newton_maxiter
        ),
        in_axes=(0, 0)
    )
    Z_coeffs_flat = jnp.array(Z_coefficients.reshape(Ntimes * Nfreqs, Ngroups))
    Lambda_inits_flat = jnp.zeros((Ntimes * Nfreqs, Ndims), dtype=float)
    Lambda_flat, niters_flat = batched_newton(Z_coeffs_flat, Lambda_inits_flat)
    _ = niters_flat.block_until_ready()  # wait for the computation to finish

    Lambda_sol = np.array(Lambda_flat).reshape(Ntimes, Nfreqs, Ndims)
    newton_iterations = np.array(niters_flat).reshape(Ntimes, Nfreqs)

    # Evaluate the objective function at the solution, again mapped over all samples
    batched_eval_Z = jax.vmap(
        lambda z_coeffs, lam: abscal._eval_Z(lam, transformed_b_vecs, z_coeffs),
        in_axes=(0, 0)
    )
    Z_sol = np.array(
        batched_eval_Z(Z_coeffs_flat, Lambda_flat)
    ).reshape(Ntimes, Nfreqs)

    # Package the solver results; note the sign flip applied to the phase gradient
    meta = {
        'Lambda_sol': -Lambda_sol,
        'transformed_antpos': transformed_antpos,
        'Z_sol': Z_sol,
        'newton_iterations': newton_iterations,
    }
    return meta
%%time
if RUN_CALIBRATION:
# Amplitude Calibration
if RUN_AMPLITUDE_CAL:
amplitude_solutions = abscal.abs_amp_lincal(
model=model,
data=data_for_cal,
wgts=wgts,
verbose=False,
)
else:
amplitude_solutions = {}
data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
amplitude_solutions['ee'] = np.ones(data_shape)
amplitude_solutions['nn'] = np.ones(data_shape)
# Tip-tilt Phase Calibration
if RUN_TIP_TILT_PHASE_CAL:
# Get the redundancies
all_reds = red_groups.RedundantGroups.from_antpos(
antpos=hd.antpos,
pols=('nn', 'ee'),
include_autos=False
)
# Fit the tip-tilt for both pols
phase_solutions = {}
for pol in ['ee', 'nn']:
phase_fit = complex_phase_abscal(
{k: data_for_cal[k][:] for k in data_for_cal if k[-1] == pol and k[0] != k[1]},
{k: model[k][:] for k in model if k[-1] == pol and k[0] != k[1]},
all_reds,
[k for k in data_for_cal if k[-1] == pol and k[0] != k[1]],
[k for k in model if k[-1] == pol and k[0] != k[1]],
)
phase_solutions[pol] = phase_fit['Lambda_sol']
transformed_antpos = phase_fit['transformed_antpos']
else:
data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
phase_solutions = {}
for pol in ['ee', 'nn']:
phase_fit = np.zeros(data_shape + (2,))
phase_solutions[pol] = phase_fit
transformed_antpos = {k: hd.antpos[k][:2] for k in hd.antpos}
if RUN_CROSS_POL_PHASE_CAL:
cross_pol_phase = abscal.cross_pol_phase_cal(
model=cross_pol_model,
data=cross_pol_data,
model_bls=list(cross_pol_model.keys()),
data_bls=list(cross_pol_data.keys()),
wgts=wgts,
)
else:
data_shape = data_for_cal[list(data_for_cal.keys())[0]].shape
cross_pol_phase = np.zeros(data_shape)
CPU times: user 4min 28s, sys: 29.5 s, total: 4min 58s Wall time: 2min 56s
if RUN_CALIBRATION:
    # Check blacklisting in amplitude calibration (harder to do in phase)
    blacklist_wgts = {}

    # Per-polarization weights, averaged over every baseline of that polarization
    avg_wgts = {
        pol: np.mean([wgts[key] for key in wgts if pol in key], axis=0)
        for pol in polarizations
    }

    for pol in ['ee', 'nn']:
        if not RUN_AMPLITUDE_CAL:
            # Without amplitude calibration the placeholder solutions are keyed by the bare
            # polarization ('ee'/'nn'), not by 'A_Jee'/'A_Jnn'; nothing gets blacklisted.
            blacklist_wgts[pol] = np.ones_like(amplitude_solutions[pol])
            continue

        raw_amplitude = amplitude_solutions[f"A_J{pol}"]
        amplitude_is_finite = np.isfinite(raw_amplitude)
        gains = np.where(amplitude_is_finite, raw_amplitude, 1.0)

        # Base weights do not change between iterations: zero for non-finite or fully flagged
        # samples, then WHERE_INPAINTED_WGTS wherever the data were inpainted
        base_wgts = np.where(amplitude_is_finite, avg_wgts[pol], 0.0)
        base_wgts = np.where(all_flagged[pol], 0.0, base_wgts)
        base_wgts = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, base_wgts)

        blacklist_wgts_here = np.ones(gains.shape,)
        for _ in range(BLACKLIST_NITER):
            wgts_here = base_wgts * blacklist_wgts_here

            # Smooth on a time scale BLACKLIST_TIMESCALE_FACTOR times longer than the final one
            smoothed_amp, _info = smooth_cal.time_freq_2D_filter(
                gains=gains.astype(complex),
                wgts=wgts_here,
                freqs=freqs,
                times=times,
                freq_scale=FREQ_SMOOTHING_SCALE,
                time_scale=TIME_SMOOTHING_SCALE * BLACKLIST_TIMESCALE_FACTOR,
                eigenval_cutoff=EIGENVAL_CUTOFF,
                method='DPSS',
                use_sparse_solver=True,
                precondition_solver=True,
                fix_phase_flips=False,
                flag_phase_flip_ints=False,
                skip_flagged_edges=True,
                freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
            )

            # Blacklist samples deviating from the smooth solution by more than the relative threshold
            blacklist_wgts_here = np.where(
                (np.abs(gains - smoothed_amp) / np.abs(smoothed_amp)) > BLACKLIST_RELATIVE_ERROR_THRESH,
                0.0,
                1.0
            )

        blacklist_wgts[pol] = blacklist_wgts_here.copy()
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
Mean of empty slice
Mean of empty slice
invalid value encountered in divide
if RUN_CALIBRATION:
    smoothed_tip_tilt = {}
    smoothed_amplitude = {}

    def _dpss_smooth(gains, wgts):
        """Smooth one complex waterfall with the notebook's DPSS time/frequency settings."""
        smoothed, _info = smooth_cal.time_freq_2D_filter(
            gains=gains,
            wgts=wgts,
            freqs=freqs,
            times=times,
            freq_scale=FREQ_SMOOTHING_SCALE,
            time_scale=TIME_SMOOTHING_SCALE,
            eigenval_cutoff=EIGENVAL_CUTOFF,
            method='DPSS',
            use_sparse_solver=True,
            precondition_solver=True,
            fix_phase_flips=False,
            flag_phase_flip_ints=False,
            skip_flagged_edges=True,
            freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6],
        )
        return smoothed

    for pol in ['ee', 'nn']:
        # --- amplitude ---
        if not RUN_AMPLITUDE_CAL:
            smoothed_amplitude[pol] = np.ones_like(amplitude_solutions[pol])
        else:
            raw_amplitude = amplitude_solutions[f"A_J{pol}"]
            amplitude_is_finite = np.isfinite(raw_amplitude)
            gains = np.where(amplitude_is_finite, raw_amplitude, 1.0)

            # zero weight for non-finite/flagged samples, small weight where inpainted,
            # then remove whatever the blacklisting step rejected
            wgts_here = np.where(amplitude_is_finite, avg_wgts[pol], 0.0)
            wgts_here = np.where(all_flagged[pol], 0.0, wgts_here)
            wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
            wgts_here = wgts_here * blacklist_wgts[pol]

            smoothed_amp = _dpss_smooth(gains.astype(complex), wgts_here)
            # fully flagged samples get unit amplitude
            smoothed_amplitude[pol] = np.where(all_flagged[pol], 1.0, smoothed_amp.real)

        # --- tip-tilt ---
        if not RUN_TIP_TILT_PHASE_CAL:
            smoothed_tip_tilt[pol] = np.zeros(wgts[list(wgts.keys())[0]].shape + (2,))
        else:
            smoothed_solutions = []
            for i in range(2):
                tip_tilt = phase_solutions[pol][..., i].astype(complex)
                tip_tilt_is_finite = np.isfinite(tip_tilt)
                gains = np.where(tip_tilt_is_finite, tip_tilt, 0.0)

                wgts_here = np.where(tip_tilt_is_finite, avg_wgts[pol], 0.0)
                wgts_here = np.where(all_flagged[pol], 0.0, wgts_here)
                wgts_here = np.where(where_inpainted[pol], WHERE_INPAINTED_WGTS, wgts_here)
                wgts_here = wgts_here * blacklist_wgts[pol]

                tip_tilt_smoothed = _dpss_smooth(gains, wgts_here)
                smoothed_solutions.append(tip_tilt_smoothed.real)

            # stack the two components on a trailing axis; fully flagged samples get zero gradient
            smoothed_tip_tilt[pol] = np.where(
                all_flagged[pol][..., None], 0.0, np.stack(smoothed_solutions, axis=-1)
            )

    # --- cross-polarized phase ---
    if not RUN_CROSS_POL_PHASE_CAL:
        cross_pol_smoothed = np.zeros(wgts[list(wgts.keys())[0]].shape)
    else:
        gains = np.where(np.isfinite(cross_pol_phase), cross_pol_phase, 0.0)
        wgts_here = np.where(np.isfinite(cross_pol_phase), avg_wgts['en'] + avg_wgts['ne'], 0.0)
        wgts_here = np.where(all_flagged['ee'] | all_flagged['nn'], 0.0, wgts_here)
        # NOTE(review): the inpainting mask and blacklist used here are those of 'nn', i.e. the
        # value the loop variable `pol` holds once the loop above has finished — confirm that
        # 'nn' alone (rather than a combination with 'ee') is what is intended.
        wgts_here = np.where(where_inpainted['nn'], WHERE_INPAINTED_WGTS, wgts_here)
        wgts_here = wgts_here * blacklist_wgts['nn']

        cross_pol_smoothed = _dpss_smooth(gains.astype(complex), wgts_here)
        cross_pol_smoothed = np.where(
            all_flagged["nn"] | all_flagged["ee"],
            0.0,
            cross_pol_smoothed
        )
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
Mean of empty slice
def plot_amplitude_degeneracy():
    """
    Plot the amplitude degeneracy before (left column) and after (right column) smoothing,
    one row per polarization ('ee', then 'nn').

    Reads the notebook-level variables times, freqs, where_inpainted, all_flagged,
    amplitude_solutions, and smoothed_amplitude.
    """
    # LSTs scaled by 12 / pi (the axis is labeled in hours); values beyond the wrap point
    # are shifted down by 24 so that the axis is continuous
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]

    imshow_kwargs = dict(
        aspect='auto', interpolation='None', vmin=0.95, vmax=1.05, extent=extent, cmap='turbo'
    )
    fig, axs = plt.subplots(2, 2, figsize=(15, 10), sharex=True, sharey=True)
    for pi, pol in enumerate(['ee', 'nn']):
        # raw solution: blank out inpainted and fully flagged samples
        axs[pi, 0].imshow(
            np.where(
                where_inpainted[pol] | (all_flagged[pol]),
                np.nan,
                np.abs(amplitude_solutions[f'A_J{pol}'])
            ),
            **imshow_kwargs
        )
        # smoothed solution: blank out fully flagged samples only
        im = axs[pi, 1].imshow(
            np.where(all_flagged[pol], np.nan, np.abs(smoothed_amplitude[pol])),
            **imshow_kwargs
        )

    plt.tight_layout()
    axs[1, 0].set_xlabel("Frequency (MHz)", fontsize=12)
    axs[1, 1].set_xlabel("Frequency (MHz)", fontsize=12)
    axs[0, 0].set_ylabel("ee", fontsize=12)
    axs[1, 0].set_ylabel("nn", fontsize=12)
    axs[0, 0].set_title("Raw Amplitude Degeneracy", fontsize=14)
    axs[0, 1].set_title("Smoothed Amplitude Degeneracy", fontsize=14)
    cbar = plt.colorbar(im, ax=axs, fraction=0.05, pad=0.01)
    cbar.set_label("Gain Amplitude", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)
def plot_tip_tilt_degeneracy():
    """
    Plot the two tip-tilt components of each polarization before (left column) and after
    (right column) smoothing; rows are ee-0, ee-1, nn-0, nn-1.

    Reads the notebook-level variables times, freqs, hd, transformed_antpos, where_inpainted,
    all_flagged, phase_solutions, and smoothed_tip_tilt.
    """
    # LSTs scaled by 12 / pi (the axis is labeled in hours); values beyond the wrap point
    # are shifted down by 24 so that the axis is continuous
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]

    # Least-squares fit of the first two antenna-position coordinates against the transformed
    # (abstracted) positions, both taken relative to the first antenna; the diagonal gives one
    # scale factor per component, used to rescale the plotted phase gradients
    antvec = np.array([hd.antpos[k][:2] for k in hd.antpos])
    antvec -= antvec[0]
    transformed_antvec = np.array([transformed_antpos[k][:2] for k in hd.antpos])
    transformed_antvec -= transformed_antvec[0]
    coord_scalar = np.diag(np.linalg.solve(transformed_antvec.T.dot(transformed_antvec), transformed_antvec.T.dot(antvec)))

    imshow_kwargs = dict(
        aspect='auto', interpolation='None', vmin=-0.05, vmax=0.05, extent=extent, cmap='turbo'
    )
    fig, axs = plt.subplots(4, 2, figsize=(15, 12), sharex=True, sharey=True)

    ci = 0  # running row index over (polarization, component)
    for pi, pol in enumerate(['ee', 'nn']):
        for ni in range(2):
            # raw solution: blank out inpainted and fully flagged samples
            axs[ci, 0].imshow(
                np.where(
                    where_inpainted[pol] | (all_flagged[pol]),
                    np.nan,
                    phase_solutions[pol][..., ni] * coord_scalar[ni],
                ),
                **imshow_kwargs
            )
            # smoothed solution: blank out fully flagged samples only
            tip_tilt = smoothed_tip_tilt[pol][..., ni] * coord_scalar[ni]
            im = axs[ci, 1].imshow(
                np.where(all_flagged[pol], np.nan, tip_tilt),
                **imshow_kwargs
            )

            # Labeling
            axs[ci, 0].set_ylabel(f"{pol} component-{ci%2}")
            ci += 1

    axs[3, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
    axs[3, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)
    axs[0, 0].set_title("Raw Tip-Tilt Degeneracy", fontsize=14)
    axs[0, 1].set_title("Smoothed Tip-Tilt Degeneracy", fontsize=14)
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs, fraction=0.04, pad=0.01)
    cbar.set_label("Phase Gradient (rad/m)", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)
def plot_cross_polarized_phase_degeneracy():
    """
    Plot the cross-polarized phase degeneracy before (left) and after (right) smoothing.

    Reads the notebook-level variables times, freqs, where_inpainted, all_flagged,
    cross_pol_phase, and cross_pol_smoothed. The 'ne' masks are used for blanking.
    """
    # LSTs scaled by 12 / pi (the axis is labeled in hours); values beyond the wrap point
    # are shifted down by 24 so that the axis is continuous
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]

    imshow_kwargs = dict(
        aspect='auto', interpolation='None', vmin=-0.1, vmax=0.1, extent=extent, cmap='coolwarm'
    )
    fig, axs = plt.subplots(1, 2, figsize=(15, 6), sharex=True, sharey=True)
    pol = 'ne'

    # raw solution: blank out inpainted and fully flagged samples
    axs[0].imshow(
        np.where(
            where_inpainted[pol] | (all_flagged[pol]),
            np.nan,
            cross_pol_phase,
        ),
        **imshow_kwargs
    )
    # smoothed solution: blank out fully flagged samples only
    im = axs[1].imshow(
        np.where(all_flagged[pol], np.nan, cross_pol_smoothed.real),
        **imshow_kwargs
    )

    # Labeling
    axs[0].set_ylabel(r"LST (hr)")
    axs[0].set_xlabel(r"Frequency (MHz)")
    axs[1].set_xlabel(r"Frequency (MHz)")
    axs[0].set_title("Raw Relative Phase Degeneracy")
    axs[1].set_title("Smoothed Relative Phase Degeneracy")
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs, fraction=0.05, pad=0.01)
    cbar.set_label("Jee/Jnn Relative Phase (rad)", fontsize=14)
def expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=True):
    """
    Expand the degenerate calibration parameters into the gain product of a single baseline,
    i.e. the factor the visibility of that baseline is divided by to calibrate it.

    Parameters:
    ----------
    key : tuple
        Baseline key like (0, 1, 'ee'), split into two antenna-polarization tuples
        with utils.split_bl.
    all_calibration_parameters : dict
        Dictionary with the amplitude parameters under 'A_Jee' and 'A_Jnn', the tip-tilt
        parameters under 'T_Jee' and 'T_Jnn' (last axis runs over position components),
        and the cross-polarized phase under 'cross_pol'.
    transformed_antpos : dict
        Dictionary mapping antenna number to its abstracted position vector, with as many
        components as the last axis of the tip-tilt parameters.
    use_cross_pol : bool
        If True (default), apply the cross-polarized phase to baselines whose two
        antenna polarizations differ.

    Returns:
    -------
    gain : np.ndarray
        Complex gain product for the baseline, with the time/frequency shape of the
        calibration parameters.
    """
    (ant1, jpol1), (ant2, jpol2) = utils.split_bl(key)

    # Amplitude: one factor per antenna polarization. (Previously the first antenna's
    # polarization was used for both factors, which is wrong for cross-polarized baselines.)
    gain = all_calibration_parameters[f"A_{jpol1}"].astype(complex) * all_calibration_parameters[f"A_{jpol2}"].astype(complex)

    # Tip-tilt: phase gradient dotted into each antenna's transformed position
    g1 = np.exp(
        1j * np.einsum("tfc,c->tf", all_calibration_parameters[f"T_{jpol2}"], transformed_antpos[ant2])
    )
    g2 = np.exp(
        1j * np.einsum("tfc,c->tf", all_calibration_parameters[f"T_{jpol1}"], transformed_antpos[ant1])
    )
    gain *= g1 * g2.conj()

    # Cross-polarized phase: the sign depends on which of the two antennas carries 'Jnn'
    if use_cross_pol and jpol1 != jpol2:
        if jpol1 == 'Jnn':
            gain *= np.exp(1j * all_calibration_parameters['cross_pol'])
        elif jpol2 == 'Jnn':
            gain *= np.exp(-1j * all_calibration_parameters['cross_pol'])

    return gain
def plot_excess_variance():
    """
    Plot the weighted mean over baselines of |V_night - V_LST|^2 / noise variance, before
    (left column) and after (right column) applying the smoothed LST-cal solutions.
    Rows are 'ee' and 'nn', followed by 'en' and 'ne' when RUN_CROSS_POL_PHASE_CAL is set.

    Reads the notebook-level variables times, freqs, data_for_cal, model, wgts, auto_model,
    transformed_antpos, the smoothed solutions, and (optionally) cross_pol_data / cross_pol_model.
    """
    # LSTs scaled by 12 / pi (the axis is labeled in hours); values beyond the wrap point
    # are shifted down by 24 so that the axis is continuous
    lsts = utils.JD2LST(times) * 12 / np.pi
    if lsts[0] > lsts[-1]:
        wrap_point = (lsts[0] + lsts[-1]) / 2
    else:
        wrap_point = (lsts[0] + lsts[-1] + 24) / 2
    lsts[wrap_point < lsts] -= 24
    extent = [freqs.min() / 1e6, freqs.max() / 1e6, lsts.max(), lsts.min()]

    all_calibration_parameters = {
        "A_Jee": smoothed_amplitude['ee'],
        "A_Jnn": smoothed_amplitude['nn'],
        "T_Jee": smoothed_tip_tilt['ee'],
        "T_Jnn": smoothed_tip_tilt['nn'],
        "cross_pol": cross_pol_smoothed,
    }

    # auto_model holds a single antenna pair; index it by polarization so that the
    # noise estimate does not depend on which antenna the autocorrelation came from
    autos_by_pol = {auto_key[-1]: auto for auto_key, auto in auto_model.items()}

    # Polarization shown in each row, in plotting order
    row_pols = ['ee', 'nn'] + (['en', 'ne'] if RUN_CROSS_POL_PHASE_CAL else [])

    imshow_kwargs = dict(
        aspect='auto', interpolation='None', cmap='turbo', vmin=0.5, vmax=10, extent=extent
    )
    if RUN_CROSS_POL_PHASE_CAL:
        fig, axs = plt.subplots(4, 2, figsize=(15, 10), sharey=True, sharex=True)
    else:
        fig, axs = plt.subplots(2, 2, figsize=(15, 6), sharey=True, sharex=True)

    for pi, pol in enumerate(['ee', 'nn']):
        excess_var = 0
        excess_var_cal = 0
        weights = 0
        for key in data_for_cal:
            if pol in key and key[0] != key[1]:
                # NOTE(review): 10 and 122e3 are presumably the integration time in seconds and
                # the channel width in Hz — confirm against the data.
                noise_var = np.abs(autos_by_pol[pol]) ** 2 / wgts[key] / 10 / 122e3
                zsquare = np.abs(data_for_cal[key] - model[key]) ** 2 / noise_var
                excess_var += zsquare * (wgts[key])

                gain = expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=True)
                data_cal = data_for_cal[key] / gain
                zsquare = np.abs(data_cal - model[key]) ** 2 / noise_var
                excess_var_cal += zsquare * (wgts[key])
                weights += (wgts[key])

        im = axs[pi, 0].imshow(np.real(np.abs(excess_var) / weights), **imshow_kwargs)
        im = axs[pi, 1].imshow(np.real(np.abs(excess_var_cal) / weights), **imshow_kwargs)

    if RUN_CROSS_POL_PHASE_CAL:
        for pi, pol in enumerate(['en', 'ne']):
            excess_var = 0
            excess_var_cal = 0
            weights = 0
            for key in cross_pol_data:
                if pol in key:
                    noise_var = np.abs(autos_by_pol["ee"] * autos_by_pol["nn"]) / wgts[key] / 10 / 122e3
                    zsquare = np.abs(cross_pol_data[key] - cross_pol_model[key]) ** 2 / noise_var
                    excess_var += zsquare * (wgts[key])

                    gain = expand_degenerate_gains_single_baseline(key, all_calibration_parameters, transformed_antpos, use_cross_pol=RUN_CROSS_POL_PHASE_CAL)
                    data_cal = cross_pol_data[key] / gain
                    zsquare = np.abs(data_cal - cross_pol_model[key]) ** 2 / noise_var
                    excess_var_cal += zsquare * (wgts[key])
                    weights += (wgts[key])

            im = axs[pi + 2, 0].imshow(np.real(np.abs(excess_var) / weights), **imshow_kwargs)
            im = axs[pi + 2, 1].imshow(np.real(np.abs(excess_var_cal) / weights), **imshow_kwargs)

    # Labeling: each row is labeled with the polarization actually plotted in it
    for i, pol in enumerate(row_pols):
        axs[i, 0].set_ylabel(pol)

    if RUN_CROSS_POL_PHASE_CAL:
        axs[3, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
        axs[3, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)
    else:
        axs[1, 0].set_xlabel(r"Frequency (MHz)", fontsize=12)
        axs[1, 1].set_xlabel(r"Frequency (MHz)", fontsize=12)

    axs[0, 0].set_title("Pre-LST Calibration", fontsize=14)
    axs[0, 1].set_title("Post-LST Calibration", fontsize=14)
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axs)
    cbar.set_label(r"Excess Variance $[|V^{\rm night} - V^{\rm LST}|^2 / V^2_N]$", fontsize=14)
    fig.text(-0.02, 0.5, 'LST (hr)', va='center', rotation='vertical', fontsize=14)
Figure 1: Amplitude Parameters Before/After Smoothing¶
# Figure 1 is drawn only when calibration ran and included the amplitude step
if RUN_CALIBRATION and RUN_AMPLITUDE_CAL:
    plot_amplitude_degeneracy()
    # optional zoom:
    #plt.ylim([14, 16])
    #plt.xlim([100, 130])
Figure 2: Tip/Tilt Parameters Before/After Smoothing¶
# Figure 2 is drawn only when calibration ran and included the tip-tilt step
if RUN_CALIBRATION and RUN_TIP_TILT_PHASE_CAL:
    plot_tip_tilt_degeneracy()
Figure 3: Cross-Polarized Phase Before/After Smoothing¶
# Figure 3 is drawn only when calibration ran and included the cross-polarized phase step
if RUN_CALIBRATION and RUN_CROSS_POL_PHASE_CAL:
    plot_cross_polarized_phase_degeneracy()
Figure 4: Visibility/LST-Averaged Variance Across Baseline Before/After LST-Cal¶
# Figure 4 is drawn whenever calibration ran
if RUN_CALIBRATION is True or RUN_CALIBRATION:
    plot_excess_variance()
divide by zero encountered in divide
invalid value encountered in divide invalid value encountered in divide
divide by zero encountered in divide
invalid value encountered in divide invalid value encountered in divide
4. Save Smoothed Results¶
if RUN_CALIBRATION:
    # Position of every calibrated integration within the full list of times of the night
    indices = np.searchsorted(single_jd_times, times)

    # Expand out tip/tilt and amplitude to full data size; integrations without a solution
    # keep the neutral values (zero phase gradient, unit amplitude)
    expanded_tip_tilt = {
        pol: np.zeros((single_jd_times.size,) + smoothed_tip_tilt[pol].shape[1:])
        for pol in smoothed_tip_tilt
    }
    expanded_amplitude = {
        pol: np.ones((single_jd_times.size,) + smoothed_amplitude[pol].shape[1:])
        for pol in smoothed_amplitude
    }
    for pol in expanded_tip_tilt:
        expanded_tip_tilt[pol][indices] = smoothed_tip_tilt[pol]
        # amplitudes that are (close to) zero are replaced by 1
        expanded_amplitude[pol][indices] = np.where(
            np.isclose(smoothed_amplitude[pol], 0.0),
            1.0,
            smoothed_amplitude[pol]
        )

    # Expand out cross-polarized degeneracy to full data size. The smoothed solution is complex
    # when it comes out of the filter; take its real part explicitly rather than relying on the
    # implicit complex-to-real cast (which emits a warning) when assigning into a float array.
    expanded_cross_pol = np.zeros((single_jd_times.size,) + cross_pol_smoothed.shape[1:])
    expanded_cross_pol[indices] = np.real(cross_pol_smoothed)

    # Expand out flags to full data size; integrations without a solution stay flagged
    expanded_flags = {
        pol: np.ones((single_jd_times.size,) + all_flagged[pol].shape[1:], dtype=bool)
        for pol in all_flagged
    }
    for pol in expanded_flags:
        expanded_flags[pol][indices] = all_flagged[pol]
Casting complex values to real discards the imaginary part
if RUN_CALIBRATION:
    # Collect the expanded (full-night) calibration parameters under the names used on disk
    all_calibration_parameters = dict(
        A_Jee=expanded_amplitude['ee'],
        A_Jnn=expanded_amplitude['nn'],
        T_Jee=expanded_tip_tilt['ee'],
        T_Jnn=expanded_tip_tilt['nn'],
        cross_pol=expanded_cross_pol,
    )

    # Output path: the LSTCAL_FNAME_FORMAT template filled with this night, under OUTDIR
    cal_fname = os.path.join(OUTDIR, LSTCAL_FNAME_FORMAT.format(night=jd_here))

    # Write the LST-cal solutions, their flags, and the antenna positions to disk
    write_solutions = lst_stack.calibration.write_single_baseline_lstcal_solutions
    write_solutions(
        filename=cal_fname,
        all_calibration_parameters=all_calibration_parameters,
        flags=expanded_flags,
        transformed_antpos=transformed_antpos,
        antpos=hd.antpos,
        times=single_jd_times,
        freqs=freqs,
        pols=polarizations
    )
import importlib

# Print the version of every package this notebook depends on. Each module is imported by
# name instead of exec-ing a generated import statement, which also avoids overwriting a
# notebook-level `__version__` on every iteration.
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')
hera_cal: 3.7.8.dev38+g2607f380d hera_qm: 2.2.1.dev12+g95ecc30f0 hera_filters: 0.1.9.dev3+ged4deb46a
hera_notebook_templates: 0.0.1.dev1479+g101a39c6e pyuvdata: 3.2.7.dev8+g5fca0c330 numpy: 2.3.5
# Total wall-clock runtime since the first cell recorded tstart, in minutes
elapsed_minutes = (time.time() - tstart) / 60
print(f'Finished execution in {elapsed_minutes:.2f} minutes.')
Finished execution in 9.81 minutes.