# NOTE(review): this module arrived whitespace-mangled -- its newlines and
# indentation had been collapsed into single spaces -- and some spans of text
# appear to be missing (probably whatever sat between a '<' and a '>'); see
# the `if kkmid:` lines in get_points4D and the "INTERPOLATE IF NUMT" marker
# in KTBLASTMETHOD_4D_kxky. The code below is the surviving text re-flowed in
# its original token order. Where the nesting was ambiguous it was inferred
# from the logic; places where that choice changes behaviour are flagged.
# Names called below but not defined in the surviving text:
# GenerateMagnetization (only its tail survives), KTBLASTMETHOD_4D_ky and
# phase_contrast.
import numpy as np
import scipy as sc
from scipy import signal
from mpi4py import MPI

# MPI handles. In the visible code `comm` is only queried for the process
# count and rank, and `rank` only gates printing and saving; no data is
# exchanged, so each rank appears to run the full computation itself.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# kt-BLAST (NO DC TERM) method for reconstruction of undersampled MRI image based on
# l2 minimization.


def EveryAliased3D2(i,j,k,PP,Nx,Ny,Nz,BB,R):
    """Return voxel (i, j, k) together with the candidate voxels obtained by
    adding every combination of per-axis offsets taken from ``PP``.

    Parameters
    ----------
    i, j, k : int
        Base voxel. Column 0/1/2 of ``PP`` is added to i/j/k respectively.
    PP : ndarray, shape (ktot, ltot)
        Per-axis offsets. The enumeration below hard-codes groups of 3 and 9,
        so it presumably expects a 3 x 3 array -- TODO confirm.
    Nx, Ny, Nz : int
        Extents used to wrap shifted coordinates back into range.
    BB : ndarray
        Mask indexed as ``BB[j, k, i]``; here it is only consulted when a
        candidate repeats an already-accepted voxel.
    R : int
        Not used by this function.

    Returns
    -------
    ndarray of int, shape (3, n)
        Column 0 is (i, j, k); the remaining columns are accepted candidates.
    """
    ivec = [i,j,k]
    Nvec = [Nx,Ny,Nz]
    [ktot,ltot] = PP.shape
    Ptot = np.zeros([ktot**ltot,ltot])
    PP2 = np.zeros([ktot**ltot,ltot])  # NOTE(review): never used
    # Enumerate the offset combinations: the row of PP used for axis 0 (tt)
    # advances every 9 entries, for axis 1 (mm) every 3 entries, and for
    # axis 2 (nn) on every entry.
    tt = -1
    for kk in range(Ptot.shape[0]):
        nn = int(np.mod(kk,3))
        mm = int(np.mod(np.floor(kk/3),3))
        if np.mod(kk,9)==0:
            tt+=1
        Ptot[kk,0] = PP[tt,0] + ivec[0]
        Ptot[kk,1] = PP[mm,1] + ivec[1]
        Ptot[kk,2] = PP[nn,2] + ivec[2]
    # Wrap each coordinate into [0, N) with at most one add/subtract of N,
    # which is only enough if the shifted value lies within one period.
    for kk in range(Ptot.shape[0]):
        for ll in range(Ptot.shape[1]):
            if Ptot[kk,ll]<0:
                Ptot[kk,ll] = Ptot[kk,ll] + Nvec[ll]
            if Ptot[kk,ll]>=Nvec[ll]:
                Ptot[kk,ll] = Ptot[kk,ll] - Nvec[ll]
    # CC holds the base voxel (column 0) followed by every candidate; YY
    # collects the accepted voxels and psel counts how many were appended.
    CC = np.zeros([3,Ptot.shape[0]+1])
    YY = np.array([ [i] , [j], [k] ])
    CC[0,0] = i
    CC[1,0] = j
    CC[2,0] = k
    psel = 0
    for l in range(1,Ptot.shape[0]+1):
        CC[0,l] = int(Ptot[l-1,0])
        CC[1,l] = int(Ptot[l-1,1])
        CC[2,l] = int(Ptot[l-1,2])
        # A candidate equal to an accepted voxel is skipped only when BB is
        # non-zero there. NOTE(review): a repeat at a voxel where BB == 0 is
        # appended again -- confirm that is intended.
        if CC[0,l]==YY[0,psel] and CC[1,l]==YY[1,psel] and CC[2,l]==YY[2,psel] and BB[int(CC[1,l]),int(CC[2,l]),int(CC[0,l])]!=0:
            pass
        else:
            War = False
            for ww in range(psel):
                if CC[0,l]==YY[0,ww] and CC[1,l]==YY[1,ww] and CC[2,l]==YY[2,ww] and BB[int(CC[1,l]),int(CC[2,l]),int(CC[0,l])]!=0:
                    War = True
            if not War:
                psel += 1
                CCC = np.array([ [CC[0,l] ] , [CC[1,l]] , [CC[2,l]]])
                YY = np.concatenate( ( YY, CCC ) ,axis=1 )
    return YY.astype(int)


def EveryAliased3D(i,j,k,DP,Nx,Ny,Nz,BB,R,SPREAD=None):
    """Return the voxels that voxel (i, j, k) is considered to alias with.

    Two strategies are available:

    * ``SPREAD`` given: an impulse at (i, j, k) is convolved with the flipped
      ``SPREAD`` kernel (scaled by R), and the coordinates of every position
      whose magnitude exceeds 0.405 times the maximum are returned.
    * ``SPREAD`` is None: (i, j, k) is shifted by each row of ``DP`` (wrapped
      into the grid), and the shifted voxels that pass the filter below are
      appended after (i, j, k).

    Parameters
    ----------
    i, j, k : int
        Base voxel.
    DP : ndarray, shape (ktot, 3)
        Offsets applied to (i, j, k) in the geometric branch.
    Nx, Ny, Nz : int
        Grid extents; the impulse volume is allocated as (Ny, Nz, Nx).
    BB : ndarray
        Mask. The geometric branch indexes it as ``BB[j, k, i]``.
    R : int
        Scale factor applied to the convolution (SPREAD branch only).
    SPREAD : ndarray, optional
        3-D kernel; when given, selects the convolution branch.

    Returns
    -------
    ndarray of int, shape (3, n)
    """
    ivec = [i,j,k]
    Nvec = [Nx,Ny,Nz]
    [ktot,ltot] = DP.shape
    DPN = np.zeros([ktot,ltot])
    if SPREAD is not None:
        # WITH SPREAD FUNCTIONS FORMALISM
        Maux = np.zeros([Ny,Nz,Nx])
        Maux[j,k,i] = 1
        SP2 = SPREAD[::-1,::-1,::-1]
        MS = R*sc.signal.convolve(Maux,SP2, mode='same')
        ms = np.abs(MS)
        Ims = 1*(ms>np.max(ms)*0.405)
        Pas = np.where(Ims==1)
        PP = np.array(Pas[:])
        # NOTE(review): np.where on the (Ny, Nz, Nx) volume returns rows in
        # (j, k, i) order, so PP[::-1] is (i, k, j). The checks below compare
        # row 1 against Ny and row 2 against Nz, and BB is indexed
        # BB[row1, row2, row0] = BB[k, j, i], whereas the geometric branch
        # uses BB[j, k, i]. Unless Ny == Nz this looks like a j/k swap --
        # confirm.
        PEA = PP[::-1,:]
        for ll in range(PEA.shape[1]):
            if PEA[0,ll]>=Nx:
                PEA[0,ll] = PEA[0,ll] - Nx
            if PEA[1,ll]>=Ny:
                PEA[1,ll] = PEA[1,ll] - Ny
            if PEA[2,ll]>=Nz:
                PEA[2,ll] = PEA[2,ll] - Nz
        # Build PEAnew = PEA without the columns where BB != 0; `ind` counts
        # deletions so the column index stays aligned with the shrinking array.
        Ntot = PEA.shape[1]
        ind = 0
        PEAnew = PEA
        for ll in range(Ntot):
            if BB[PEA[1,ll],PEA[2,ll],PEA[0,ll]]!=0:
                PEAnew = np.delete(PEAnew,(ll-ind),axis=1)
                ind +=1
        # NOTE(review): the unfiltered PEA is returned and PEAnew is
        # discarded; returning PEAnew looks like the intent -- confirm.
        return PEA
    else:
        # Shift (i, j, k) by each row of DP, wrapping once into the grid.
        for kk in range(DPN.shape[0]):
            for l in range(DPN.shape[1]):
                DPN[kk,l] = DP[kk,l] + ivec[l]
                if DPN[kk,l]<0:
                    DPN[kk,l] = DPN[kk,l] + Nvec[l]
                if DPN[kk,l]>=Nvec[l]:
                    DPN[kk,l] = DPN[kk,l] - Nvec[l]
        CC = np.zeros([3,ktot+1])
        YY = np.array([ [i] , [j], [k] ])
        CC[0,0] = i
        CC[1,0] = j
        CC[2,0] = k
        for l in range(1,ktot+1):
            CC[0,l] = DPN[l-1,0]
            CC[1,l] = DPN[l-1,1]
            CC[2,l] = DPN[l-1,2]
            # Keep a candidate when all three coordinates differ from the
            # previous candidate (column l-1, accepted or not) and BB is zero
            # there. NOTE(review): `and` rejects candidates that share even one
            # coordinate with their predecessor -- confirm `and` vs `or`.
            if CC[0,l]!=CC[0,l-1] and CC[1,l]!=CC[1,l-1] and CC[2,l]!=CC[2,l-1] and BB[int(CC[1,l]),int(CC[2,l]),int(CC[0,l])]==0:
                CCC = np.array([ [CC[0,l] ] , [CC[1,l]] , [CC[2,l]]])
                YY = np.concatenate( ( YY, CCC ) ,axis=1 )
        return YY.astype(int)


def EveryAliased(i,j,DP,Nx,Ny,BB,R,mode):
    """2-D variant: return point (i, j) and the points it is considered to
    alias with.

    Parameters
    ----------
    i, j : int
        Base point; DP column 0 is added to i and column 1 to j.
    DP : ndarray, shape (ktot, 2)
        Offsets (mode == 1 only).
    Nx, Ny : int
        Extents used to wrap shifted coordinates.
    BB : ndarray
        Mask indexed as ``BB[j, i]`` (mode == 1 only).
    R : int
        Scale factor (mode == 'spread' only).
    mode : {1, 'spread'}
        1 uses the geometric offsets; 'spread' uses a spread-function
        convolution. Any other value falls through and returns None.

    Returns
    -------
    ndarray of int, shape (2, n), or None
    """
    if mode==1:
        # USING GEOMETRICAL ASSUMPTIONS
        ivec = [i,j]
        Nvec = [Nx,Ny]
        DPN = 0*DP  # zeros with the same shape and dtype as DP
        [ktot,ltot] = DP.shape
        for k in range(ktot):
            for l in range(ltot):
                DPN[k,l] = DP[k,l] + ivec[l]
                if DPN[k,l]<0:
                    #DPN[k,l] = ivec[l]
                    DPN[k,l] = DPN[k,l] + Nvec[l]
                if DPN[k,l]>=Nvec[l]:
                    #DPN[k,l] = ivec[l]
                    DPN[k,l] = DPN[k,l] - Nvec[l]
        CC = np.zeros([2,ktot+1])
        YY = np.array([ [i] , [j] ])
        CC[0,0] = i
        CC[1,0] = j
        for l in range(1,ktot+1):
            CC[0,l] = DPN[l-1,0]
            CC[1,l] = DPN[l-1,1]
            # Keep a candidate when both coordinates differ from the previous
            # candidate and BB is zero there.
            if CC[0,l]!=CC[0,l-1] and CC[1,l]!=CC[1,l-1] and BB[int(CC[1,l]),int(CC[0,l])]==0:
                CCC = np.array([ [CC[0,l] ] , [CC[1,l]] ])
                YY = np.concatenate( ( YY, CCC ) ,axis=1 )
        return YY.astype(int)
    if mode=='spread':
        # WITH SPREAD FUNCTIONS FORMALISM
        # NOTE(review): row, numt2, l, k and SPREAD are not assigned anywhere
        # in this function, so unless they exist as module globals this
        # branch raises NameError.
        Maux = np.zeros([row,numt2])
        Maux[l,k] = 1
        MS = R*ConvolSP(Maux,SPREAD)
        ms = np.abs(MS)
        Ims = 1*(ms>np.max(ms)*0.405)
        Pas = np.where(Ims==1)
        PP = np.array(Pas[:])
        return PP[::-1,:]


def GetSymmetric(M):
    """Return a complex copy of 2-D ``M`` made symmetric along axis 1.

    Columns k and numt2-1-k are both replaced by their average; with an odd
    number of columns the middle column is averaged with itself and so is
    unchanged.
    """
    [row,numt2] = M.shape
    S = np.zeros(M.shape,dtype=complex)
    aux = np.zeros([1,row])  # overwritten in the loop; this allocation is unused
    nmid = 0.5*(numt2+1)
    for k in range(int(nmid)):
        aux = 0.5*( M[:,k] + M[:,numt2-k-1] )
        S[:,k] = aux
        S[:,numt2-k-1] = aux
    return S


def UNDER(A,mode,R,k):
    """Return a complex copy of 3-D ``A`` that keeps only every R-th line
    and zeroes the rest.

    The first kept index is ``k mod R``, so passing the frame number as
    ``k`` shifts the sampling pattern from frame to frame.

    Parameters
    ----------
    A : ndarray, 3-D
    mode : {'ky', 'kxky', 'kxkykz'}
        'ky' subsamples axis 0; 'kxky' subsamples axes 0 and 1 (same offset
        on both); 'kxkykz' subsamples all three axes. Any other value
        returns an all-zero array.
    R : int
        Subsampling step.
    k : int
        Offset selector (see above).
    """
    start = np.mod(k,R)
    I1B = np.zeros(A.shape,dtype=complex)
    # Not quite efficient ! better to work with vectors
    # (the loops below reuse the name k; start has already been computed,
    # so this does not affect the result)
    if mode=='ky':
        for k in range(start,A.shape[0],R):
            I1B[k,:,:] = A[k,:,:]
    if mode=='kxky':
        for k in range(start,A.shape[0],R):
            for l in range(start,A.shape[1],R):
                I1B[k,l,:] = A[k,l,:]
    if mode=='kxkykz':
        for k in range(start,A.shape[0],R):
            for l in range(start,A.shape[2],R):
                for r in range(start,A.shape[1],R):
                    I1B[k,r,l] = A[k,r,l]
    return I1B


def FilteringHigh(M,fac):
    """Taper the FFT of ``M`` along axis 1 and transform back.

    Each 1-D signal along axis 1 is Fourier transformed, multiplied by the
    second half of a Tukey window of length 2*col with shape parameter
    ``fac``, and inverse transformed. For 0 < fac < 1 those weights are 1
    for the first bins and taper to 0 at the last bin. Handles 2-D input
    and, slice by slice along axis 2, 3-D input; any other rank returns
    None.

    NOTE(review): the taper is applied to the unshifted FFT output, whose
    upper half of bins holds the negative frequencies, so the weighting is
    not symmetric in |frequency| -- confirm this is intended.
    NOTE(review): ``signal.tukey`` is the old alias of
    ``scipy.signal.windows.tukey``; recent SciPy releases no longer provide
    the alias.
    """
    if M.ndim==2:
        [row,col] = M.shape
        inx = np.linspace(0,col-1,col)  # unused
        MF = np.zeros(M.shape,dtype=complex)
        for k in range(row):
            vecfou = np.fft.fft(M[k,:])
            # Depends only on col and fac; recomputed on every row.
            window = signal.tukey(2*col,fac)
            vecfou2 = vecfou*window[col:2*col]
            MF[k,:] = np.fft.ifft(vecfou2)
        return MF
    if M.ndim==3:
        [row,col,dep] = M.shape
        MF = np.zeros(M.shape,dtype=complex)
        inx = np.linspace(0,col-1,col)  # unused
        for l in range(dep):
            for k in range(row):
                vecfou = np.fft.fft(M[k,:,l])
                window = signal.tukey(2*col,fac)
                vecfou2 = vecfou*window[col:2*col]
                MF[k,:,l] = np.fft.ifft(vecfou2)
        return MF


def InterpolateM(M,numt2):
    """Lengthen the last axis of ``M`` from numt to ``numt2`` samples.

    The original samples are copied unchanged into indices
    [nstar, nstar+numt) with nstar = int((numt2-numt)/2). The remaining
    indices are filled from a cubic interpolant of each 1-D signal,
    evaluated on ``numt2`` points spanning the same interval as the input.
    Supports 2-D (row, numt) and 3-D (row, col, numt) input.

    NOTE(review): the leading fill starts at index 1, so index 0 of the
    last axis stays zero -- confirm.
    NOTE(review): for any other rank MNew is never assigned and the return
    raises. ``sc.interpolate`` is reached only through ``import scipy as sc``;
    this file never imports scipy.interpolate explicitly.
    """
    if M.ndim==2:
        [row,numt] = M.shape
        MNew = np.zeros([row,numt2],dtype=complex)
        xdat = np.linspace(0,numt,numt)
        xdat2 = np.linspace(0,numt,numt2)
        nstar = int(0.5*(numt2-numt))
        for t in range(nstar,nstar+numt):
            MNew[:,t] = M[:,t-nstar]
        for l in range(row):
            ydat = M[l,:]
            fdat = sc.interpolate.interp1d(xdat,ydat,kind='cubic')
            MNew[l,1:nstar] = fdat(xdat2)[1:nstar]
            MNew[l,nstar+numt:numt2] = fdat(xdat2)[nstar+numt:numt2]
    if M.ndim==3:
        [row,col,numt] = M.shape
        MNew = np.zeros([row,col,numt2],dtype=complex)
        xdat = np.linspace(0,numt,numt)
        xdat2 = np.linspace(0,numt,numt2)
        nstar = int(0.5*(numt2-numt))
        for t in range(nstar,nstar+numt):
            MNew[:,:,t] = M[:,:,t-nstar]
        for c in range(col):
            for l in range(row):
                ydat = M[l,c,:]
                fdat = sc.interpolate.interp1d(xdat,ydat,kind='cubic')
                MNew[l,c,1:nstar] = fdat(xdat2)[1:nstar]
                MNew[l,c,nstar+numt:numt2] = fdat(xdat2)[nstar+numt:numt2]
    return MNew


def KTT(M,scol):
    """Transform 3-D ``M`` (axis 2 iterated as frames) and return column ``scol``.

    Each frame M[:, :, k] is ifftshift-ed and inverse 2-D FFT'd, column
    ``scol`` is selected (giving a (row, frames) array), and the frame axis
    is inverse FFT'd and fftshift-ed -- presumably k-t data to x-f space.
    """
    #Maux = M[:,scol,:]
    #Maux = np.fft.ifftshift(Maux,axes=0)
    #Maux = np.fft.ifft(Maux,axis=0)
    #Maux = np.fft.ifft(Maux,axis=1)
    #Maux = np.fft.fftshift(Maux,axes=1)
    # TAO STYLE
    Maux = np.zeros(M.shape,dtype=complex)
    for k in range(M.shape[2]):
        Maux[:,:,k] = np.fft.ifftshift(M[:,:,k])
        Maux[:,:,k] = np.fft.ifft2(Maux[:,:,k])
    Maux = Maux[:,scol,:]
    Maux = np.fft.ifft(Maux,axis=1)
    Maux = np.fft.fftshift(Maux,axes=1)
    return Maux


def IKTT(M):
    """Undo the frame-axis step of KTT along axis 1: ifftshift, then forward FFT.

    The spatial inverse FFT performed in KTT is not reverted here.
    """
    #Maux = np.fft.ifftshift(M,axes=1)
    #Maux = np.fft.ifft(Maux,axis=1)
    #Maux = np.fft.fft(Maux,axis=0)
    #Maux = np.fft.fftshift(Maux,axes=0)
    # TAO STYLE
    Maux = np.fft.ifftshift(M,axes=1)
    Maux = np.fft.fft(Maux,axis=1)
    return Maux


def KTT3D(M,sdep):
    """4-D counterpart of KTT: transform ``M`` (axis 3 iterated as frames)
    and return slice ``sdep`` of axis 2.

    Each frame is ifftshift-ed and inverse n-D FFT'd, slice ``sdep`` is
    selected (giving a 3-D array whose last axis is the frame axis), and that
    axis is inverse FFT'd and fftshift-ed.
    """
    Maux = np.zeros(M.shape,dtype=complex)
    for k in range(M.shape[3]):
        Maux[:,:,:,k] = np.fft.ifftshift(M[:,:,:,k])
        Maux[:,:,:,k] = np.fft.ifftn(Maux[:,:,:,k])
    Maux = Maux[:,:,sdep,:]
    Maux = np.fft.ifft(Maux,axis=2)
    Maux = np.fft.fftshift(Maux,axes=2)
    return Maux


def IKTT3D(M):
    """Undo the frame-axis step of KTT3D along axis 2: ifftshift, then FFT."""
    Maux = np.fft.ifftshift(M,axes=2)
    Maux = np.fft.fft(Maux,axis=2)
    return Maux


def get_points4D(row,col,dep,numt2,R,mode):
    """Compute R+1 lattice points (PP), their offsets from the centre PC (DP),
    and an index range [kmin, kmax] along column 0, for step R.

    Parameters
    ----------
    row, col, numt2 : int
        Extents; PP column 0 is derived from numt2, column 1 from row and
        (mode 'kxky') column 2 from col.
    dep : int
        Not used.
    R : int
        Step factor; points = R + 1.
    mode : {'kxky', 'ky'}
        'kxky' gives 3 columns, 'ky' gives 2. Any other value returns None.

    Returns
    -------
    [kmin, kmax, PP, DP]
        PP has shape (R+1, ncols), DP has shape (R, ncols).

    NOTE(review): cmin/cmax are computed in 'kxky' mode but not returned.
    """
    bb = np.ceil(row/R)
    cc = np.ceil(col/R)
    aa = np.ceil(numt2/R)
    points = R+1
    kmid = int(R/2)
    if mode=='kxky':
        PC = [np.ceil(numt2/2),np.ceil(row/2),np.ceil(col/2)]
        PP = np.zeros([points,3])
        DP = np.zeros([points-1,3])
        for k in range(points):
            PP[k,0] = numt2-aa*(k) + 1
            PP[k,1] = bb*(k)
            PP[k,2] = cc*(k)
            # Pull out-of-range coordinates back by a single step of 1.
            # NOTE(review): for k == 0, PP[0,0] starts at numt2+1 and ends at
            # numt2, which is still >= numt2 -- confirm.
            if PP[k,0]>=numt2:
                PP[k,0] -= 1
            if PP[k,0]<0:
                PP[k,0] += 1
            if PP[k,1]>=row:
                PP[k,1] -= 1
            if PP[k,1]<0:
                PP[k,1] += 1
            if PP[k,2]>=col:
                PP[k,2] -= 1
            if PP[k,2]<0:
                PP[k,2] += 1
            # NOTE(review): SOURCE IS DAMAGED HERE. `kkmid` is not a defined
            # name, so this line raises NameError. Text between a '<' and a
            # '>' seems to have been lost -- probably a `k<kmid` test with its
            # own DP assignments, followed by `if k>kmid:` guarding the
            # DP[k-1, ...] lines below. Restore from the original source.
            if kkmid:
                DP[k-1,0] = PP[k,0] - PC[0]
                DP[k-1,1] = PP[k,1] - PC[1]
                DP[k-1,2] = PP[k,2] - PC[2]
        # kmax / kmin: midpoints (column 0) between point kmid and its
        # neighbours kmid-1 / kmid+1; cmax / cmin do the same for column 1.
        kmax = int((PP[kmid,0] + PP[kmid-1,0])/2 )
        kmin = int((PP[kmid,0] + PP[kmid+1,0])/2 )
        cmax = int((PP[kmid,1] + PP[kmid-1,1])/2 )
        cmin = int((PP[kmid,1] + PP[kmid+1,1])/2 )
        #DP2 = np.zeros([DP.shape[0]**DP.shape[1],DP.shape[1]])
        #DP2[0,0] = DP[0,0]; DP2[0,1] = DP[0,1] ; DP2[0,2] = DP[0,2]
        #DP2[1,0] = DP[0,0]; DP2[1,1] = DP[0,1] ; DP2[1,2] = DP[1,2]
        #DP2[2,0] = DP[0,0]; DP2[2,1] = DP[1,1] ; DP2[2,2] = DP[0,2]
        #DP2[3,0] = DP[0,0]; DP2[3,1] = DP[1,1] ; DP2[3,2] = DP[1,2]
        #DP2[4,0] = DP[1,0]; DP2[4,1] = DP[0,1] ; DP2[4,2] = DP[0,2]
        #DP2[5,0] = DP[1,0]; DP2[5,1] = DP[0,1] ; DP2[5,2] = DP[1,2]
        #DP2[6,0] = DP[1,0]; DP2[6,1] = DP[1,1] ; DP2[6,2] = DP[0,2]
        #DP2[7,0] = DP[1,0]; DP2[7,1] = DP[1,1] ; DP2[7,2] = DP[1,2]
        return [kmin,kmax,PP,DP]
    if mode=='ky':
        PC = [np.ceil(numt2/2),np.ceil(row/2)]
        PP = np.zeros([points,2])
        DP = np.zeros([points-1,2])
        for k in range(points):
            # NOTE: column 0 uses (aa-1) and no +1 here, unlike 'kxky' mode.
            PP[k,0] = numt2-(aa-1)*(k)
            PP[k,1] = bb*(k)
            # NOTE(review): SOURCE IS DAMAGED HERE -- same undefined `kkmid`
            # as in the 'kxky' branch; raises NameError as written.
            if kkmid:
                DP[k-1,0] = PP[k,0] - PC[0]
                DP[k-1,1] = PP[k,1] - PC[1]
        kmax = int((PP[kmid,0] + PP[kmid-1,0])/2 )
        kmin = int((PP[kmid,0] + PP[kmid+1,0])/2 )
        return [kmin,kmax,PP,DP]


def SpreadPoint3D(M,R,sdep):
    """Build a row/column sampling mask with the shape of 4-D ``M`` and
    return its transform for slice ``sdep``.

    Every R-th row is marked, together with every R-th column from an offset
    derived from ``iny`` (see note). The same mask is applied to every slice
    and every frame. Each frame is then inverse n-D FFT'd and fftshift-ed,
    slice ``sdep`` of axis 2 is selected, and the frame axis is inverse FFT'd
    and fftshift-ed. Only the shape of ``M`` is used.
    """
    [row,col,dep,numt2] = M.shape
    PS = np.zeros([row,col,dep,numt2],dtype=complex)
    inx = 0
    iny = 0
    for k in range(np.mod(inx,R),row,R):
        for ss in range(np.mod(iny,R),col,R):
            PS[k,ss,:,:] = 1
        # NOTE(review): the original nesting of these two increments was lost.
        # As laid out here, iny advances once per marked row, so the column
        # offset shifts by one on each marked row, whereas UNDER('kxky') uses
        # the same offset for rows and columns. inx is incremented after its
        # only use and has no effect. Confirm the intended nesting.
        iny = iny + 1
    inx = inx + 1
    for k in range(numt2):
        PS[:,:,:,k] = np.fft.ifftn(PS[:,:,:,k])
        PS[:,:,:,k] = np.fft.fftshift(PS[:,:,:,k])
    SPREAD = PS[:,:,sdep,:]
    SPREAD = np.fft.ifft(SPREAD,axis=2)
    SPREAD = np.fft.fftshift(SPREAD,axes=2)
    return SPREAD


def SpreadPoint(M,R,scol):
    """Return the transformed sampling pattern for column ``scol`` of a mask
    shaped like 3-D ``M``.

    Frame l marks every R-th row starting at l mod R, which is the same row
    pattern UNDER(..., 'ky', R, l) keeps, with all columns set. Each frame
    is then inverse 2-D FFT'd and fftshift-ed, column ``scol`` is selected,
    and the frame axis is inverse FFT'd and fftshift-ed. Only the shape of
    ``M`` is used.

    NOTE: KTT applies ifftshift before ifft2, whereas this applies fftshift
    after it.
    """
    [row,col,numt2] = M.shape
    PS = np.zeros([row,col,numt2],dtype=complex)
    inx = 0
    for l in range(0,numt2):
        for k in range(np.mod(inx,R),row,R):
            PS[k,:,l] = 1
        inx = inx + 1
    #PS = 1*(M!=0)
    #PS = 0*M + 1
    #SPREAD = KTT(PS,0)
    for k in range(numt2):
        PS[:,:,k] = np.fft.ifft2(PS[:,:,k])
        PS[:,:,k] = np.fft.fftshift(PS[:,:,k])
    SPREAD = PS[:,scol,:]
    SPREAD = np.fft.ifft(SPREAD,axis=1)
    SPREAD = np.fft.fftshift(SPREAD,axes=1)
    return SPREAD


def ConvolSP(M1,M2):
    """2-D convolution of ``M1`` with ``M2`` flipped on both axes, using
    periodic ('wrap') boundaries; the output has the size of ``M1``.

    For real ``M2`` this is a circular cross-correlation; no conjugation is
    applied for complex ``M2``.
    """
    M2 = M2[::-1,::-1]
    M3 = sc.signal.convolve2d(M1,M2, boundary='wrap', mode='same')
    return M3


def KTBLASTMETHOD_4D_kxky(ITOT,R,mode):
    """kt-BLAST reconstruction of 4-D data ``ITOT`` with shape (row, col, dep, numt2).

    The visible steps are:

    1. FFT every frame to k-space (fftn + fftshift).
    2. Undersample each frame with UNDER(mode, R).
    3. Copy the central (2*Dyy+1) x (2*Dyy+1) block of the first two axes of
       the fully sampled k-space into the training data.
    4. Compute lattice points with get_points4D.
    5. Compute the spread function of the undersampled data with
       SpreadPoint3D.
    6. Loop over slices.

    NOTE(review): most of this function is missing from the surviving text
    (see the damage note inside the slice loop); there is no visible return.
    """
    ###################################################################
    #                        Training Stage                           #
    ###################################################################
    [row,col,dep,numt2] = ITOT.shape
    # INPUT PARAMETERS
    iteshort = 1  # 1: reconstruct only the middle slice (see ZE1/ZE2 below)
    tetest = int(dep/2)
    numt = int(numt2)
    Dyy = int(row*0.1)  # half-width of the training block; applied to both of the first two axes
    rmid = int(row/2)
    cmid = int(col/2)
    TKdata = np.zeros(ITOT.shape,dtype=complex)
    UKdata = np.zeros(ITOT.shape,dtype=complex)
    Kdata = np.zeros(ITOT.shape,dtype=complex)
    # Kdata_NEW0, Kdata_NEW, KTBLAST0 and KTBLAST are not used in the
    # surviving text; presumably the missing part fills them.
    Kdata_NEW0 = np.zeros(ITOT.shape,dtype=complex)
    Kdata_NEW = np.zeros(ITOT.shape,dtype=complex)
    KTBLAST0 = np.zeros(ITOT.shape,dtype=complex)
    KTBLAST = np.zeros(ITOT.shape,dtype=complex)
    for k in range(numt2):
        # THE FULL KSPACE
        Kdata[:,:,:,k] = np.fft.fftn(ITOT[:,:,:,k])
        Kdata[:,:,:,k] = np.fft.fftshift(Kdata[:,:,:,k])
        # UNDERSAMPLING STEP AND FILLED WITH ZEROS THE REST
        UKdata[:,:,:,k] = UNDER(Kdata[:,:,:,k],mode,R,k)
    # GENERATING THE TRAINING DATA WITH SUBSAMPLING the Center IN KX , KY
    for k in range(numt):
        TKdata[rmid-Dyy:rmid+Dyy+1,cmid-Dyy:cmid+Dyy+1,:,k] = Kdata[rmid-Dyy:rmid+Dyy+1,cmid-Dyy:cmid+Dyy+1,:,k]
    [kmin,kmax,PP,DP] = get_points4D(row,col,dep,numt2,R,mode)
    if iteshort==1:
        # NOTE(review): printed on every MPI rank (not gated by rank == 0).
        print(PP)
        print(DP)
        print('range of k = ',kmin,kmax)
    SPREAD = SpreadPoint3D(UKdata,R,tetest)
    ###################################################################
    #                        RECONSTRUCTION                           #
    ###################################################################
    # iteshort == 1: zi runs over [tetest, tetest+1), i.e. only the middle
    # slice; iteshort == 0: zi runs over all dep slices.
    ZE1 = iteshort + (tetest-1)*(iteshort)
    ZE2 = (tetest+1)*(iteshort) + (dep)*(1-iteshort)
    for zi in range(ZE1,ZE2):
        if rank==0:
            print('4D KTBLAST: R = ' + str(R) + ' and z = ' + str(zi)+'/'+str(dep))
        ### CONSTRUCT THE REFERENCE M_TRAINING
        B2 = KTT3D(TKdata,zi)
        B2 = FilteringHigh(B2,0.3)
        M2 = 4*np.abs(B2)**2
        #M2 = GetSymmetric(M2)
        ### INTERPOLATE IF NUMT0.001)
        # ==================================================================
        # NOTE(review): SOURCE IS DAMAGED HERE. The marker above runs straight
        # from "IF NUMT" into "0.001)". The rest of this function, the
        # definitions of KTBLASTMETHOD_4D_ky and phase_contrast, and the start
        # of GenerateMagnetization appear to have been lost.
        # The statements from here down to `return [RHO0,RHO1]` use names
        # (scantype, gamma, B0, TE, X, Sq, VENC, varPHASE0, PHASE0, PHASE1,
        # RHO0, RHO1, modulus, Drho, Drho2) that are not defined earlier in
        # this function. They are presumably the tail of
        # GenerateMagnetization, which undersampling() calls and unpacks as
        # [M0, M1]. Their 3-D indexing ([:,:,k]) suggests a (row, col, time)
        # input branch.
        # They are kept here only so that no surviving text is lost, and
        # their indentation is a placeholder. Restore the missing code from
        # the original source.
        # ==================================================================
        if scantype=='0G':
            PHASE0[:,:,k] = (gamma*B0*TE+0.01*X)*(np.abs(Sq[:,:,k])>0.001) + 10*varPHASE0
            PHASE1[:,:,k] = (gamma*B0*TE+0.01*X)*(np.abs(Sq[:,:,k])>0.001) + 10*varPHASE0 + np.pi*Sq[:,:,k]/VENC
        if scantype=='-G+G':
            PHASE0[:,:,k] = gamma*B0*TE*np.ones([row,col]) + 10*varPHASE0 - np.pi*Sq[:,:,k]/VENC
            PHASE1[:,:,k] = gamma*B0*TE*np.ones([row,col]) + 10*varPHASE0 + np.pi*Sq[:,:,k]/VENC
        RHO0[:,:,k] = modulus*np.cos(PHASE0[:,:,k]) + Drho + 1j*modulus*np.sin(PHASE0[:,:,k]) + 1j*Drho2
        RHO1[:,:,k] = modulus*np.cos(PHASE1[:,:,k]) + Drho + 1j*modulus*np.sin(PHASE1[:,:,k]) + 1j*Drho2
    # (GenerateMagnetization tail, continued) 4-D input with axes
    # (row, col, dep, numt2).
    if np.ndim(Sq)==4:
        [row,col,dep,numt2] = Sq.shape
        # With np.meshgrid's default 'xy' indexing, X, Y, Z have shape
        # (row, col, dep); X varies along axis 1.
        [X,Y,Z] = np.meshgrid(np.linspace(0,col,col),np.linspace(0,row,row),np.linspace(0,dep,dep))
        for k in range(numt2):
            # Additive Gaussian noise (std 0.2) on the real and imaginary
            # parts when `noise` is truthy.
            if noise:
                Drho = np.random.normal(0,0.2,[row,col,dep])
                Drho2 = np.random.normal(0,0.2,[row,col,dep])
            else:
                Drho = np.zeros([row,col,dep])
                Drho2 = np.zeros([row,col,dep])
            # Random integer phase in [-10, 10] degrees (as radians), applied
            # only where |Sq| < 0.001.
            varPHASE0 = np.random.randint(-10,11,size=(row,col,dep))*np.pi/180*(np.abs(Sq[:,:,:,k])<0.001)
            # Magnitude 1.0 where |Sq| > 0.001, otherwise 0.5.
            modulus = 0.5 + 0.5*(np.abs(Sq[:,:,:,k])>0.001)
            if scantype=='0G':
                # PHASE1 differs from PHASE0 by pi*Sq/VENC.
                PHASE0[:,:,:,k] = (gamma*B0*TE+0.01*X)*(np.abs(Sq[:,:,:,k])>0.001) + 10*varPHASE0
                PHASE1[:,:,:,k] = (gamma*B0*TE+0.01*X)*(np.abs(Sq[:,:,:,k])>0.001) + 10*varPHASE0 + np.pi*Sq[:,:,:,k]/VENC
            if scantype=='-G+G':
                # Symmetric encoding: -/+ pi*Sq/VENC.
                # NOTE(review): the 3-D-indexed branch above uses 10*varPHASE0
                # here, while this branch uses varPHASE0 -- confirm which one
                # is intended.
                PHASE0[:,:,:,k] = gamma*B0*TE*np.ones([row,col,dep]) + varPHASE0 - np.pi*Sq[:,:,:,k]/VENC
                PHASE1[:,:,:,k] = gamma*B0*TE*np.ones([row,col,dep]) + varPHASE0 + np.pi*Sq[:,:,:,k]/VENC
            # Complex signal modulus*exp(i*PHASE) plus the noise terms.
            RHO0[:,:,:,k] = modulus*np.cos(PHASE0[:,:,:,k]) + Drho + 1j*modulus*np.sin(PHASE0[:,:,:,k]) + 1j*Drho2
            RHO1[:,:,:,k] = modulus*np.cos(PHASE1[:,:,:,k]) + Drho + 1j*modulus*np.sin(PHASE1[:,:,:,k]) + 1j*Drho2
    return [RHO0,RHO1]


def undersampling(Sqx,Sqy,Sqz,options,savepath):
    """Run the kt-BLAST undersampling/reconstruction for the three components
    Sqx, Sqy, Sqz and optionally save the result.

    For every value r in ``options['kt-BLAST']['R']``, and for each of Sqx,
    Sqy and Sqz:

    1. Build [M0, M1] with GenerateMagnetization(..., scantype='0G').
    2. Swap axes 1 and 2.
    3. Reconstruct M0 and M1 with the kt-BLAST routine selected by ``mode``.
    4. Combine them with phase_contrast.

    The three results are then swapped back. If
    ``options['kt-BLAST']['save']`` is truthy, rank 0 writes
    ``savepath + <name>_R<r>.npz`` with arrays x, y and z.

    Parameters
    ----------
    Sqx, Sqy, Sqz : ndarray
        Presumably velocity components; the transpose((0,2,1,3)) calls imply
        4-D arrays.
    options : dict
        Reads the options['kt-BLAST'] keys 'R' (iterated over), 'mode'
        ('ky' or 'kxky'), 'VENC', 'noise', 'save' and 'name'.
    savepath : str
        Prefix joined to the file name by plain string concatenation, so it
        must end with a path separator.

    NOTE(review): for any mode other than 'ky'/'kxky', M0_kt/M1_kt are never
    assigned and the phase_contrast call raises NameError.
    GenerateMagnetization, KTBLASTMETHOD_4D_ky and phase_contrast are not
    defined in the surviving text of this module.
    """
    R = options['kt-BLAST']['R']
    mode = options['kt-BLAST']['mode']
    transpose = True
    for r in R:
        if rank==0:
            print('Using Acceleration Factor R = ' + str(r))
            print('Component x of M0')
        # ---- x component ----
        [M0,M1] = GenerateMagnetization(Sqx,options['kt-BLAST']['VENC'],options['kt-BLAST']['noise'],scantype='0G')
        if transpose:
            M0 = M0.transpose((0,2,1,3))
            M1 = M1.transpose((0,2,1,3))
        if mode=='ky':
            M0_kt = KTBLASTMETHOD_4D_ky(M0,r,mode)
        if mode=='kxky':
            M0_kt = KTBLASTMETHOD_4D_kxky(M0,r,mode)
        if rank==0:
            print('\n Component x of M1')
        if mode=='ky':
            M1_kt = KTBLASTMETHOD_4D_ky(M1,r,mode)
        if mode=='kxky':
            M1_kt = KTBLASTMETHOD_4D_kxky(M1,r,mode)
        Sqx_kt = phase_contrast(M1_kt,M0_kt,options['kt-BLAST']['VENC'],scantype='0G')
        del M0,M1
        del M0_kt, M1_kt
        # ---- y component ----
        [M0,M1] = GenerateMagnetization(Sqy,options['kt-BLAST']['VENC'],options['kt-BLAST']['noise'],scantype='0G')
        if transpose:
            M0 = M0.transpose((0,2,1,3))
            M1 = M1.transpose((0,2,1,3))
        if rank==0:
            print('\n Component y of M0')
        if mode=='ky':
            M0_kt = KTBLASTMETHOD_4D_ky(M0,r,mode)
        if mode=='kxky':
            M0_kt = KTBLASTMETHOD_4D_kxky(M0,r,mode)
        if rank==0:
            print('\n Component y of M1')
        if mode=='ky':
            M1_kt = KTBLASTMETHOD_4D_ky(M1,r,mode)
        if mode=='kxky':
            M1_kt = KTBLASTMETHOD_4D_kxky(M1,r,mode)
        Sqy_kt = phase_contrast(M1_kt,M0_kt,options['kt-BLAST']['VENC'],scantype='0G')
        del M0,M1
        del M0_kt, M1_kt
        # ---- z component ----
        [M0,M1] = GenerateMagnetization(Sqz,options['kt-BLAST']['VENC'],options['kt-BLAST']['noise'],scantype='0G')
        if transpose:
            M0 = M0.transpose((0,2,1,3))
            M1 = M1.transpose((0,2,1,3))
        if rank==0:
            print('\n Component z of M0')
        if mode=='ky':
            M0_kt = KTBLASTMETHOD_4D_ky(M0,r,mode)
        if mode=='kxky':
            M0_kt = KTBLASTMETHOD_4D_kxky(M0,r,mode)
        if rank==0:
            print('\n Component z of M1')
        if mode=='ky':
            M1_kt = KTBLASTMETHOD_4D_ky(M1,r,mode)
        if mode=='kxky':
            M1_kt = KTBLASTMETHOD_4D_kxky(M1,r,mode)
        if rank==0:
            print(' ')
        Sqz_kt = phase_contrast(M1_kt,M0_kt,options['kt-BLAST']['VENC'],scantype='0G')
        # Undo the axis swap applied to the inputs.
        if transpose:
            Sqx_kt = Sqx_kt.transpose((0,2,1,3))
            Sqy_kt = Sqy_kt.transpose((0,2,1,3))
            Sqz_kt = Sqz_kt.transpose((0,2,1,3))
        if options['kt-BLAST']['save']:
            # NOTE(review): nesting reconstructed -- naming and saving are
            # placed under rank == 0 together with the prints.
            if rank==0:
                print('saving the sequences in ' + savepath)
                seqname = options['kt-BLAST']['name'] +'_R' + str(r) + '.npz'
                print('sequence name: ' + seqname)
                np.savez_compressed( savepath + seqname, x=Sqx_kt, y=Sqy_kt,z=Sqz_kt)
        del Sqx_kt,Sqy_kt,Sqz_kt