Skip to content

Commit

Permalink
add submasks, t_filter_factor, and return_arrivaltimes options.
Browse files Browse the repository at this point in the history
* submasks allows you to mask one component without affecting the others
* t_filter_factor allows a user to adjust the width of the time filter
(and is the analog of bw_filter_factor)
* return_arrivaltimes allows you to get a dataframe of the fit points
for external analysis
* update measure_allmethods with plots and options for fit p0s
  • Loading branch information
mef51 committed May 7, 2024
1 parent 6afa7d8 commit 525831a
Showing 1 changed file with 100 additions and 19 deletions.
119 changes: 100 additions & 19 deletions method_tests/arrivaltimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,11 @@ def measureburst(
downfactors=(1,1),
subtractbg=False,
bw_filter_factor=3,
t_filter_factor=2,
crop=None,
masks=[],
submasks=None,
return_arrivaltimes=False,
outdir='',
show=True,
show_components=False,
Expand Down Expand Up @@ -237,9 +240,17 @@ def measureburst(
bw_filter_factor (int, optional): By default 3 $sigma$ of the burst bandwidth is applied as a spectral filter.
For bursts with lots of frequency structure this may be inadequate,
and this parameter can be used to override the filter width. It's recommended to try downsampling first.
t_filter_factor (int, optional): By default 2 $sigma$ of the burst duration is applied as a temporal filter.
outdir (str, optional): string of output folder for figures. Defaults to ''.
crop (tuple[int], optional): pair of indices to crop the waterfall in time
masks (List[int], optional): frequency indices to mask. Masks are applied before downsampling
submasks (tuple[List[int]], optional): tuple of length `xos` of lists of indices to mask on a subcomponent's waterfall.
Note that contrary to `masks`, these are applied after downsampling.
Indices are scaled from the original size to the downsampled size and so can cover more than one channel.
The length of `submasks` must match the length of `xos`.
Example: To specify a mask on the 4th component of a waterfall with 4 components, pass
``submask=([],[],[],[22])``
return_arrivaltimes (bool, optional): If True, will return a dataframe of the arrival times per channel
show (bool, optional): if True show interactive figure window for each file
show_components (bool, optional): if True show figure window for each sub-burst
save (bool, optional): if True save a figure displaying the measurements.
Expand Down Expand Up @@ -271,6 +282,8 @@ def measureburst(
'tb (ms)', # nu0dtaudnu
'tb_err'
```
arrtimesdf (pd.DataFrame): Only return when `return_arrivaltimes` is True.
"""
if type(xos) == tuple:
if len(xos) != 2:
Expand All @@ -293,6 +306,12 @@ def measureburst(
data = np.load(filename, allow_pickle=True)
wfall = np.copy(data['wfall'])

if not submasks:
submasks = ([],)*len(xos)
else:
if len(submasks) != len(xos):
raise ValueError("Please ensure the length of xos and submasks match")

if targetDM:
print(f"Info: Dedispersing from {data['DM']} to {targetDM} pc/cm3")
ddm = targetDM - data['DM']
Expand Down Expand Up @@ -362,7 +381,7 @@ def measureburst(
xos.append(pktime)

## Assuming 1 burst:
window = 2*abs(t_popt[2]) # 2*stddev of a guassian fit to the integrated time series
# window = t_filter_factor*abs(t_popt[2]) # 2*stddev of a guassian fit to the integrated time series

##### multi component model: use multiple 1d gaussians to make multiple windows in time,
# then use the time windows to make frequency windows
Expand Down Expand Up @@ -449,9 +468,15 @@ def measureburst(
'darkgreen',
'brown'
])
for subfall, subband, xosi, sigma, sigma_err in zip(
subfalls, subbands, xos, tmix_sigmas, tmix_sigma_errs

for subfall, subband, xosi, sigma, sigma_err, submask in zip(
subfalls, subbands, xos, tmix_sigmas, tmix_sigma_errs, submasks
):
for m in submask:
if type(m) == range:
m = np.array(m)
subfall[m//downfactors[0]] = 0

sigma = abs(sigma)
subdf = fitrows(subfall, res_time_ms, freqs) # Fit a 1d gaussian in each row of the waterfall
if len(cuts) == 0:
Expand Down Expand Up @@ -488,7 +513,7 @@ def measureburst(
printd(f"Debug: pre-filters {len(subdf) = }")
subdf = subdf[(subdf.amp > 0)]
subdf = subdf[subdf.tstart_err/subdf.tstart < 10]
subdf = subdf[(subpktime-2*sigma < subdf[tpoint]) & (subdf[tpoint] < subpktime+2*sigma)]
subdf = subdf[(subpktime-t_filter_factor*sigma < subdf[tpoint]) & (subdf[tpoint] < subpktime+t_filter_factor*sigma)]
printd(f"Debug: post-filters {len(subdf) = }")

if bwidth != 1:
Expand Down Expand Up @@ -658,7 +683,7 @@ def measureburst(
ax_wfall.plot(
times_ms-pktime,
(1/dtdnu)*(times_ms-xoi) - tb/dtdnu,
'w--',
'w-.',
alpha=0.75,
# label=f'$dt/d\\nu = $ {dtdnu:.2e} $\\pm$ {dtdnu_err:.2e}'
label=f'{subburst_suffixes[xos.index(xoi)]}. $dt/d\\nu =$ {scilabel(dtdnu, dtdnu_err)} ms/MHz'
Expand All @@ -673,7 +698,7 @@ def measureburst(
# plot filter windows (time)
sp = 0
for s, xoi in zip(tmix_sigmas, xos):
w = 2*np.abs(s)
w = t_filter_factor*np.abs(s)
ax_tseries.add_patch(Rectangle(
(xoi-pktime-w, ax_tseries.get_ylim()[0] + sp*(np.max(tseries)*0.075)),
width=2*w,
Expand Down Expand Up @@ -819,9 +844,12 @@ def printinfo(event):
print(f"Info: Saved {outname}.")

plt.close()
return results
if return_arrivaltimes:
return results, subdf
else:
return results

def measure_allmethods(filename, show=True, **kwargs):
def measure_allmethods(filename, show=True, p0tw=0.01, p0bw=100, **kwargs):
"""
Collect spectro-temporal measurements of a burst using multiple techniques.
Utility for comparing the result of spectro-temporal measurements of a burst obtained from the following
Expand Down Expand Up @@ -863,15 +891,28 @@ def measure_allmethods(filename, show=True, **kwargs):
aspect='auto',
origin='lower',
extent=extent,
interpolation='none'
interpolation='none',
norm='linear',
vmax=np.quantile(wfall, 0.999),
)
axs[0].annotate(
f"$DM =$ {targetDM:.3f} pc/cm$^3$",
xy=(0.05, 0.925),
xycoords='axes fraction',
color='white',
weight='black',
size=10,
bbox={"boxstyle":"round"}
)
axs[0].set_xlabel("Time (ms)")
axs[0].set_ylabel("Frequency (MHz)")

## Arrival times measurement
arr_result = measureburst(
arr_result, arrtimesdf = measureburst(
filename,
save=False,
show=False,
return_arrivaltimes=True,
**kwargs
)

Expand All @@ -884,11 +925,35 @@ def measure_allmethods(filename, show=True, **kwargs):
print("Arrival Times method:")
print(f"\t{arrdf.iloc[0]['dtdnu (ms/MHz)']:.4e} +/- {arrdf.iloc[0]['dtdnu_err']:.4e} ms/MHz")

## ACF measurement ( -3470 mhz/ms in frbgui)
if len(arrtimesdf) > 0:
axs[0].scatter( # component fit points
arrtimesdf['tstart'],
arrtimesdf['freqs'],
c='w',
edgecolors=arrtimesdf['color'],
marker='o',
s=25,
alpha=np.clip(arrtimesdf['amp'], 0, 1),
label=f"$dt/d\\nu =$ {scilabel(arrdf.iloc[0]['dtdnu (ms/MHz)'], arrdf.iloc[0]['dtdnu_err'])} ms/MHz"
)

axs[0].plot( # <--- This line does not appear in the correct spot
times_ms-pktime,
(1/arrdf.iloc[0]['dtdnu (ms/MHz)'])*(times_ms-pktime) - arrdf.iloc[0]['tb (ms)']/arrdf.iloc[0]['dtdnu (ms/MHz)'],
'w--',
alpha=0.75,
# label=f'$dt/d\\nu = $ {dtdnu:.2e} $\\pm$ {dtdnu_err:.2e}'
label='arrival times fit'
)
axs[0].set_xlim(extent[0], extent[1])
axs[0].set_ylim(extent[2], extent[3])


## ACF measurement (-3470 mhz/ms in frbgui)
p0s = [
[1, pktime, freqs[len(freqs)//2], 0.01, 100, 0], # amp, xo, yo, sigma_x, sigma_y, theta
[1, pktime, 0., freqs[len(freqs)//2], 0.01, 100], # amp, t0, dt, nu0, sigma_t, sigma_nu
[1, pktime, 0., freqs[len(freqs)//2], 0.01, 100] # amp, t0, dnu, nu0, w_t, w_nu
[1, pktime, freqs[len(freqs)//2], p0tw, p0bw, 0], # amp, xo, yo, sigma_x, sigma_y, theta
[1, pktime, 0., freqs[len(freqs)//2], p0tw, p0bw], # amp, t0, dt, nu0, sigma_t, sigma_nu
[1, pktime, 0., freqs[len(freqs)//2], p0tw, p0bw] # amp, t0, dnu, nu0, w_t, w_nu
]

(
Expand Down Expand Up @@ -923,17 +988,20 @@ def measure_allmethods(filename, show=True, **kwargs):
aspect='auto',
cmap='gray',
extent=corrext,
origin='lower'
origin='lower',
)
corrplot.set_clim(0, np.max(corr))
if cpopt[0] > 0:
axs[1].contour(
c = axs[1].contour(
fitmap,
[cpopt[0]/4, cpopt[0]*0.9],
colors='b',
alpha=0.33,
extent=corrext,
)
c.collections[0].set_label(
f"ACF: $dt/d\\nu =$ {scilabel(1/slope, slope_error/slope**2)} ms/MHz"
)

## Gaussian models measurements
models = [driftrate.twoD_Gaussian, driftrate.gaussian_dt, driftrate.gaussian_dnu] # preserve order
Expand Down Expand Up @@ -962,12 +1030,14 @@ def measure_allmethods(filename, show=True, **kwargs):
print(f"\t{slope:.4e} MHz/ms +/- {slope_error}")
print(f"\t{1/slope:.4e} +/- {slope_error/slope**2:.4e} ms/MHz")
precalc_results += [1/slope, slope_error/slope**2]
legend_lbl = f"G: $dt/d\\nu =$ {scilabel(1/slope, slope_error/slope**2)} ms/MHz"
elif model == driftrate.gaussian_dt: # amp, t0, dt, nu0, sigma_t, sigma_nu
units = ['', 'ms', 'ms/MHz', 'MHz', 'ms', 'MHz']
lbls = ['$A_{dt}$', '$t_0$', '$d_t$', '$\\nu_0$', '$\\sigma_t$', '$\\sigma_\\nu$']

print(f"\t{popt[2]:.4e} +/- {perr[2]:.4e} ms/MHz")
precalc_results += [1/slope, slope_error/slope**2]
legend_lbl = f"$d_t =$ {scilabel(popt[2], perr[2])} ms/MHz"
elif model == driftrate.gaussian_dnu: # amp, t0, dnu, nu0, w_t, w_nu
units = ['', 'ms', 'MHz/ms', 'MHz', 'ms', 'MHz']
lbls = ['$A_{dnu}$', '$t_0$', '$d_\\nu$', '$\\nu_0$', '$w_t$', '$w_\\nu$']
Expand All @@ -976,6 +1046,9 @@ def measure_allmethods(filename, show=True, **kwargs):

print(f"\t {gnu_dt = :.4e} ms/MHz")
precalc_results += [gnu_dt, -1] # uncertainty needs to be derived from eq. A4 in Jahns+2023
legend_lbl = ''#f"$d_{{t,g\\nu}} =$ {scilabel(gnu_dt, -1)} ms/MHz"
# if perr[0] == np.inf:
# perr = [-1]*len(perr)

# Output
model_results += list(popt)+list(perr)
Expand All @@ -992,29 +1065,37 @@ def measure_allmethods(filename, show=True, **kwargs):
)

if popt[0] > 0:
axs[0].contour(
c = axs[0].contour(
poptmap,
[popt[0]/4, popt[0]*0.9],
colors='w',
alpha=0.33,
extent=extent,
)
c.collections[0].set_label(legend_lbl)
else: # Bad fit, plot in red
axs[0].contour(
c = axs[0].contour(
poptmap,
[-popt[0]/4, -popt[0]*0.9],
colors='r',
alpha=0.33,
extent=extent,
)
c.collections[0].set_label(legend_lbl)
bname = filename.split('/')[-1].split('.')[0]
axs[0].set_title(bname)

try:
axs[0].legend(handlelength=0)
except IndexError as e:
print("error: weird legend bug")
axs[1].legend(handlelength=0)

results_allmethods = list(arrdf.reset_index().iloc[0]) + precalc_results + model_results

if show: plt.show()
plt.savefig(f"measurements/collected/{bname}")

plt.close()
return results_allmethods

allmethods_columns = [
Expand Down

0 comments on commit 525831a

Please sign in to comment.