import numpy as np
import matplotlib.pyplot as plt
Automatic Aperture Filtering for Microscopy Using Shapes
Overview
The aperture/pupil in the back focal plane (BFP) of a microscope determines the resolution of the image that can be captured by the microscope. For a microscope, the aperture is basically the tube inside the microscope through which light passes. The BFP exists in Fourier space, and the aperture essentially blocks spatial frequencies beyond its width at the BFP. In the corresponding image at the camera, any spatial frequencies beyond the aperture width in the BFP cannot possibly have been made by the sample, because they were blocked by the aperture from getting to the camera. The goal is to utilize this understanding of the inner workings of the microscope to denoise a microscopy image by accounting for the aperture/pupil in the BFP.
We remove high-frequency noise that is beyond the aperture by modeling the pupil function of the image (i.e., the aperture) as a big circle and removing frequencies in the image outside of the aperture. By Fourier transforming the image, we expect to see non-zero power from spatial frequencies out to some radius away from the center – that circle is the aperture. Beyond this circle, the power should be zero, based on our physical understanding of a microscope.
In this example, we vary the radius of the circle to model the Fourier transformed image as zeros outside of the circle, and some number greater than zero inside the circle. Once the best pupil function is found, i.e. the one whose shape best matches the experimentally observed Fourier image, any spatial frequencies in the image beyond this point are zeroed out to low-pass filter that noise out of the image.
In this notebook, we use the Rat Hippocampal Neuron .tif example file from FIJI, which has several color channels.
## Open the file. It was opened from FIJI and then File > save as > Tif
from skimage import io

## img0 is indexed by channel along its first axis further down (img0[-1], img0[i])
img0 = io.imread('Rat_Hippocampal_Neuron.tif').astype('double') ## from FIJI examples

## Use the DIC image to find the aperture
img = img0[-1]
## Plot image
## The grey scale is limited to mean +/- one standard deviation of the image
fig,ax = plt.subplots(1,1,figsize=(6,6))
ax.imshow(img,cmap='Greys_r',origin='lower',vmin=img.mean()-1.*img.std(),vmax=img.mean()+1.*img.std())
ax.set_title('Real Space')
plt.tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
plt.show()
## Calculate various Fourier transform things
## fftshift moves the zero-frequency term to the centre of the array
ft = np.fft.fftshift(np.fft.fft2(img))
mag_img = np.abs(ft)
## np.angle is the four-quadrant phase; arctan(imag/real) folds it into
## (-pi/2, pi/2) and divides by zero wherever the real part vanishes
phase_img = np.unwrap(np.angle(ft))
ln_mag = np.log(mag_img)
## Plot the magnitude and phase of the Fourier transform of the image
fig,ax = plt.subplots(1,2,figsize=(12,6))
## clip the grey scale to the 1st-99th percentile of the log-magnitude
vmin,vmax = np.percentile(ln_mag,(1,99))
## extent puts the centre of the shifted transform at the origin of both axes
ax[0].imshow(ln_mag,cmap='Greys_r',origin='lower',vmin=vmin,vmax=vmax,
    extent=[-ln_mag.shape[1]/2.,ln_mag.shape[1]/2., -ln_mag.shape[0]/2., ln_mag.shape[0]/2. ])
ax[0].set_title('Fourier Space (ln(Magnitude))')
ax[1].imshow(phase_img,cmap='Greys_r',origin='lower',
    extent=[-phase_img.shape[1]/2.,phase_img.shape[1]/2., -phase_img.shape[0]/2., phase_img.shape[0]/2. ])
ax[1].set_title('Fourier Space (Phase)')
plt.show()
This evidence function calculates a circular mask/template given a radius, r0. Any pixels within r0 have value 1. and any outside r0 have value 0. It’s like a 2D tophat function.
The evidence for this template is computed for m>0, b in R, tau>0 (SI 2.2.2). It is compared against the evidence for a flat or ‘Null’ template, i.e., m=0, b in R, tau>0 (SI 2.2.4).
This function returns the (negative) log ratio of those two evidences so that the minimizer functions in scipy.optimize can find the maximum.
## pre-computed mask calculation parameters
nx,ny = ft.shape
kx,ky = np.mgrid[0:nx,0:ny]
## fftshift places the zero-frequency term at index n//2, so centre the grid
## there; n/2. is half a pixel off whenever a dimension is odd
kx = kx.astype('double') - (nx//2)
ky = ky.astype('double') - (ny//2)
## squared radial distance of every pixel from the zero-frequency pixel
kr2 = kx**2. + ky**2.
from scipy.special import betainc,betaln,gammaln
def ln_bayes_factor(theta,y,kr2_grid=None):
    """Score a circular (top-hat) template of radius ``theta`` against ``y``.

    The template is 1 where the squared radial distance is below ``theta**2``
    and 0 elsewhere. The returned value is built from the squared correlation
    between the template and ``y``; callers minimize it (or maximize its
    negative) to find the best radius.

    Parameters
    ----------
    theta : radius r0 of the template, in pixels. A float or a one-element
        array (as passed by scipy.optimize.minimize).
    y : array with the same shape as the radial grid, e.g. the log-magnitude
        of the shifted Fourier transform.
    kr2_grid : optional array of squared radial distances. Defaults to the
        module-level ``kr2`` when omitted.

    Returns
    -------
    The log ratio ``lnR``, or ``np.inf`` when the radius is below 5 or the
    squared correlation is within 1e-10 of 0 or 1.
    """
    if kr2_grid is None:
        kr2_grid = kr2

    r0 = theta
    ## model is out of bounds
    if r0 < 5:
        return np.inf

    ## make the template
    x = (kr2_grid < r0**2.).astype('double')
    N = float(x.size)

    ## first and second moments of the template and the data
    ex = np.nanmean(x)
    exx = np.nanmean(x*x)
    ey = np.nanmean(y)
    eyy = np.nanmean(y*y)
    exy = np.nanmean(x*y)

    ## (co)variances; the 1e-300 offset keeps the logs and the division finite
    vx = exx - ex*ex + 1e-300
    vy = eyy - ey*ey + 1e-300
    vxy = exy - ex*ey + 1e-300

    ## correlation coefficient between the template and the data
    r = vxy/np.sqrt(vx*vy)
    r2 = r*r
    ## degenerate correlation: log(1-r2) or the incomplete beta term would blow up
    if r2 < 1e-10 or r2 > 1.-1e-10:
        return np.inf

    M = (N-2.)/2.

    ## delm enters only through the additive constant log(delm)
    delm = 1e30
    lnR = np.log(2.) + np.log(delm) - betaln(.5,M) + .5*np.log(vx) - .5*np.log(vy) + M*np.log(1.-r2) - np.log(1.+np.sign(r)*betainc(.5,M,r2))

    return lnR
## Sanity check: evaluate the function once at a trial radius of 50 pixels
print(ln_bayes_factor(50., ln_mag))
-13934.824440777767
### Scan the r0 parameter to see if there is a maximum
xs = np.linspace(10,400,1000)*1.
ev = np.zeros_like(xs)
from tqdm.notebook import trange
## ev stores the negated value so that the best radius shows up as a maximum
for i in trange(xs.size):
    ev[i] = -ln_bayes_factor((xs[i]),ln_mag)

fig,ax = plt.subplots(1,figsize=(12,4))
ax.plot(xs,ev)
plt.show()
### Find the maximum prob aperture mask radius
from scipy.optimize import minimize
def wrapper(initial_guess,ln_mag):
    """Objective handed to scipy.optimize.minimize: the value to be minimized.

    Returns ln_bayes_factor(initial_guess, ln_mag) unchanged. Squashing it
    through -1/(1+exp(lnR)) rounds to exactly -1.0 in double precision once
    lnR is below roughly -40, which makes the objective flat around the
    optimum and leaves Nelder-Mead with nothing to descend.

    Parameters
    ----------
    initial_guess : trial radius supplied by the optimizer.
    ln_mag : data array forwarded to ln_bayes_factor.
    """
    lnR = ln_bayes_factor(initial_guess,ln_mag)
    ## a NaN would defeat the simplex vertex comparisons; rank it as the worst value
    if np.isnan(lnR):
        return np.inf
    return lnR
## start the optimizer from the best radius found in the grid scan
initial_guess = np.array((xs[np.nanargmax(ev)]))
out = minimize(wrapper, initial_guess, args=(ln_mag,), method='Nelder-Mead')
print(out)

## binary aperture mask at the optimized radius
r0 = out.x
mask = (kr2 < r0**2.).astype('int')
final_simplex: (array([[138.43843844],
[138.43849125]]), array([-1., -1.]))
fun: -1.0
message: 'Optimization terminated successfully.'
nfev: 53
nit: 18
status: 0
success: True
x: array([138.43843844])
### Plot the best mask
## left: log-magnitude of the transform; right: the same multiplied by the mask
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0].imshow(ln_mag,cmap='Greys_r',origin='lower',vmin=vmin,vmax=vmax,
    extent=[-ln_mag.shape[1]/2.,ln_mag.shape[1]/2., -ln_mag.shape[0]/2., ln_mag.shape[0]/2.])
ax[1].imshow(ln_mag*mask,cmap='Greys_r',origin='lower',vmin=vmin,vmax=vmax,
    extent=[-ln_mag.shape[1]/2.,ln_mag.shape[1]/2., -ln_mag.shape[0]/2., ln_mag.shape[0]/2.])
ax[0].set_title('Fourier Space')
ax[1].set_title('Fourier Space (Masked)')
plt.show()
### Calculate and Plot the low pass filtered image
## Zero every spatial frequency outside the aperture. The mask is real, so a
## plain product is enough; multiplying by (mask+mask*1j) scales the spectrum
## by (1+1j) and subtracts the imaginary part of the result from .real
filtered = ft*mask
## ifftshift is the exact inverse of the earlier fftshift (the two differ for odd sizes)
filtered = np.fft.ifft2(np.fft.ifftshift(filtered)).real
residual = img-filtered

## shared grey scale: mean +/- one standard deviation of the unfiltered image
img_lo = img.mean()-1.*img.std()
img_hi = img.mean()+1.*img.std()

fig,ax = plt.subplots(3,2,figsize=(12,18))
ax[0,0].imshow(img,cmap='Greys_r',origin='lower',vmin=img_lo,vmax=img_hi)
ax[0,1].imshow(filtered,cmap='Greys_r',origin='lower',vmin=img_lo,vmax=img_hi)
ax[1,0].imshow(img,cmap='Greys_r',origin='lower',vmin=img_lo,vmax=img_hi)
ax[1,1].imshow(filtered,cmap='Greys_r',origin='lower',vmin=img_lo,vmax=img_hi)
ax[2,0].hist(residual.flatten(),bins=250,log=True)
## narrow percentile window to accentuate structure in the residual
ax[2,1].imshow(residual,cmap='Greys_r',origin='lower',vmin=np.percentile(residual,35),vmax=np.percentile(residual,65))

## middle row: zoom both panels onto the same 64-128 pixel window
for aa in ax[1]:
    aa.set_xlim(64,128)
    aa.set_ylim(64,128)

ax[0,0].set_title('Real Space')
ax[0,0].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[0,1].set_title('Real Space (Filtered)')
ax[0,1].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[1,0].set_title('Real Space (Zoom)')
ax[1,0].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[1,1].set_title('Real Space (Filtered, Zoom)')
ax[1,1].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[2,0].set_title('Residual')
ax[2,1].set_title('Residual (Accentuated)')
ax[2,1].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)

plt.show()
### Show the original and filtered image (first three fluorescence channels) in RGB
filtered = np.zeros((nx,ny,3))
full = np.zeros_like(filtered)
for i in range(3):
    ## per-channel transform kept in its own name so the DIC transform `ft` is not overwritten
    ft_channel = np.fft.fftshift(np.fft.fft2(img0[i]))
    ## the mask is real, so a plain product zeroes the frequencies outside the aperture
    fd = ft_channel*mask
    ## ifftshift is the exact inverse of fftshift (the two differ for odd sizes)
    fd = np.fft.ifft2(np.fft.ifftshift(fd)).real
    filtered[:,:,i] = fd
    full[:,:,i] = img0[i]

## normalise each channel by the maximum of the unfiltered channel
scaling = full.max((0,1))[None,None,:]
filtered /= scaling
full /= scaling
## filtering can push values slightly outside [0, 1]; clip explicitly for RGB display
filtered = np.clip(filtered,0.,1.)

fig,ax = plt.subplots(1,2,figsize=(16,8))
ax[0].imshow(full,origin='lower',interpolation='nearest')
ax[0].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[1].imshow(filtered,origin='lower',interpolation='nearest')
ax[1].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[0].set_title('Original')
ax[1].set_title('Filtered')

plt.show()
## Same RGB comparison, zoomed onto the 200-328 pixel window in both panels
fig,ax = plt.subplots(1,2,figsize=(16,8))
ax[0].imshow(full,origin='lower',interpolation='nearest')
ax[1].imshow(filtered,origin='lower',interpolation='nearest')
ax[0].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[1].tick_params(bottom=False, left=False, labelleft=False, labelbottom=False)
ax[0].set_title('Original (Zoom)')
ax[1].set_title('Filtered (Zoom)')
for aa in ax:
    aa.set_xlim(200,328)
    aa.set_ylim(200,328)
plt.show()