source: trunk/python/sbseparator.py @ 2647

Last change on this file since 2647 was 2647, checked in by Kana Sugimoto, 12 years ago

New Development: Yes

JIRA Issue: Yes (CAS-4141/CSV-1526)

Ready for Test: Yes

Interface Changes: Yes

What Interface Changed: a new module

Test Programs:

Put in Release Notes: Yes

Module(s): itself (sbseparator)

Description:

A new module, sbseparator.
This module separates the signal and image sideband spectra, or suppresses
emission from the image sideband, in spectra taken with frequency switching.
An FFT technique is used in this algorithm (see Emerson, Klein, and Haslam 1979, A&A 76, 92).
For now, this module is mainly for testing purposes, to see how
the algorithm works on real data.


File size: 22.3 KB
Line 
1import os, shutil
2import numpy
3import numpy.fft as FFT
4import math
5
6from asap.scantable import scantable
7from asap.parameters import rcParams
8from asap.logging import asaplog, asaplog_post_dec
9from asap.selector import selector
10from asap.asapgrid import asapgrid2
11#from asap._asap import sidebandsep
12
class sbseparator:
    """
    The sbseparator class is defined to separate USB and LSB spectra
    observed with a DSB receiver. It also helps suppress emission from
    the image sideband.

    The input spectra are taken with frequency setups shifted by
    different numbers of channels. Pairs of such spectra are deconvolved
    in Fourier space (Emerson, Klein, and Haslam 1979, A&A 76, 92) and
    the pairwise solutions are averaged.

    Example:
        # Create sideband separator instance with 3 input data
        sbsep = sbseparator(['test1.asap', 'test2.asap', 'test3.asap'])
        # Set reference IFNO and tolerance to select data
        sbsep.set_frequency(5, 30, frame='TOPO')
        # Set direction tolerance to select data in unit of radian
        sbsep.set_dirtol(1.e-5)
        # Set rejection limit of solution
        sbsep.set_limit(0.2)
        # Solve image sideband as well
        sbsep.set_both(True)
        # Invoke sideband separation
        sbsep.separate('testout.asap', overwrite = True)
    """
    def __init__(self, infiles):
        """
        infiles : a list of filenames or scantables (a single filename or
                  scantable is also accepted).
        """
        self.intables = None
        self.signalShift = []
        self.imageShift = []
        self.dsbmode = True
        self.getboth = False
        self.rejlimit = 0.2
        self.baseif = -1
        self.freqtol = 10.   # in channel unit unless a dict in Hz
        self.freqframe = ""
        self.solveother = False
        self.dirtol = [1.e-5, 1.e-5] # direction tolerance in rad (2 arcsec)

        self.tables = []
        self.nshift = -1
        self.nchan = -1

        self.set_data(infiles)

    @asaplog_post_dec
    def set_data(self, infiles):
        """
        Set data to be processed.

        infiles  : a list of filenames or scantables

        Raises ValueError when no data is given or a file does not exist,
        and TypeError when a list of scantables contains other objects.
        """
        if not isinstance(infiles, (list, tuple, numpy.ndarray)):
            infiles = [infiles]
        if len(infiles) == 0:
            asaplog.post()
            raise ValueError("No input data is given.")
        if isinstance(infiles[0], scantable):
            # a list of scantable
            for stab in infiles:
                if not isinstance(stab, scantable):
                    asaplog.post()
                    raise TypeError("Input data is not a list of scantables.")
        else:
            # a list of filenames
            for name in infiles:
                if not os.path.exists(name):
                    asaplog.post()
                    raise ValueError("Could not find input file '%s'" % name)

        self._reset_data()
        # Keep a private list so that processing never mutates the
        # caller's container (and an ndarray is never truth-tested).
        self.intables = list(infiles)
        asaplog.push("%d files are set to process" % len(self.intables))

    def _reset_data(self):
        """Discard the input data and every quantity derived from it."""
        self.intables = None
        self.signalShift = []
        # imageShift is left untouched so that a shift set by
        # set_shift(mode='SSB') survives a change of input data.
        self.tables = []
        self.nshift = -1
        self.nchan = -1

    @asaplog_post_dec
    def set_frequency(self, baseif, freqtol, frame=""):
        """
        Set IFNO and frequency tolerance to select data to process.

        Parameters:
          baseif  : reference IFNO to process in the first table in the list
          freqtol : frequency tolerance from reference IF to select data.
                    Either a number in channel unit or a dictionary
                    {'value': <float>, 'unit': 'Hz'}.
          frame   : frequency frame to select IF

        An invalid tolerance is reported as an ERROR in the log and the
        current settings are left unchanged.
        """
        # Validate before touching any state so that a failed call does
        # not leave the object with an unusable tolerance.
        if isinstance(freqtol, dict):
            if freqtol.get("unit") != "Hz":
                asaplog.post()
                asaplog.push("Frequency tolerance given as a dictionary should be in unit of 'Hz'.")
                asaplog.post("ERROR")
                return
            tolvalue = freqtol.get("value", 0.)
        else:
            # tolerance in channel unit
            tolvalue = freqtol
        try:
            tolvalue = float(tolvalue)
        except (TypeError, ValueError):
            tolvalue = -1.   # rejected just below
        if not tolvalue > 0.:
            asaplog.post()
            asaplog.push("Frequency tolerance should be positive value.")
            asaplog.post("ERROR")
            return

        self._reset_if()
        self.baseif = baseif
        if isinstance(freqtol, dict):
            self.freqtol = dict(value=tolvalue, unit="Hz")
        else:
            self.freqtol = tolvalue
        self.freqframe = frame

    def _reset_if(self):
        """Reset the IF selection and every quantity derived from it."""
        self.baseif = -1
        self.freqtol = 0
        self.freqframe = ""
        self.signalShift = []
        # imageShift is kept (see _reset_data).
        self.tables = []
        self.nshift = -1
        self.nchan = -1

    @asaplog_post_dec
    def set_dirtol(self, dirtol=(1.e-5, 1.e-5)):
        """
        Set tolerance of direction to select data.

        dirtol : direction tolerance in unit of radian. Either a single
                 positive value used for both axes or a list of [x, y]
                 tolerances (only the first two elements are used).
        """
        # direction tolerance in rad
        if not isinstance(dirtol, (list, tuple, numpy.ndarray)):
            dirtol = [dirtol, dirtol]
        if len(dirtol) == 1:
            dirtol = [dirtol[0], dirtol[0]]
        if len(dirtol) < 2:
            asaplog.post()
            asaplog.push("Invalid direction tolerance. Should be a list of float in unit radian")
            asaplog.post("ERROR")
            return
        if not (dirtol[0] > 0 and dirtol[1] > 0):
            asaplog.post()
            asaplog.push("Direction tolerance should be positive value.")
            asaplog.post("ERROR")
            return
        self.dirtol = [float(dirtol[0]), float(dirtol[1])]
        asaplog.push("Set direction tolerance [%f, %f] (rad)" % \
                     (self.dirtol[0], self.dirtol[1]))

    @asaplog_post_dec
    def set_shift(self, mode="DSB", imageshift=None):
        """
        Set shift mode and channel shift of image band.

        mode       : shift mode ['DSB'|'SSB']
                     When mode='DSB', imageshift is assumed to be equal
                     to the shift of signal sideband but in opposite direction.
        imageshift : a list of number of channel shift in image band of
                     each selected IF. valid only mode='SSB'

        Raises ValueError when mode='SSB' and no image shift is given.
        """
        if mode.upper().startswith("S"):
            if imageshift is None:
                shifts = []
            else:
                shifts = [float(s) for s in numpy.ravel(imageshift)]
            if len(shifts) == 0:
                raise ValueError("Need to set shift value of image sideband")
            self.dsbmode = False
            self.imageShift = shifts
            asaplog.push("Image sideband shift is set manually: %s" % str(self.imageShift))
        else:
            # DSB mode
            self.dsbmode = True
            self.imageShift = []

    @asaplog_post_dec
    def set_both(self, flag=False):
        """
        Resolve both image and signal sideband when True.
        """
        self.getboth = flag
        if self.getboth:
            asaplog.push("Both signal and image sidebands are solved and output as separate tables.")
        else:
            asaplog.push("Only signal sideband is solved and output as an table.")

    @asaplog_post_dec
    def set_limit(self, threshold=0.2):
        """
        Set rejection limit of solution.

        Fourier components whose |sin(phase difference)| is not above
        this threshold are not corrected during deconvolution.
        """
        self.rejlimit = threshold
        asaplog.push("The threshold of rejection is set to %f" % self.rejlimit)

    @asaplog_post_dec
    def set_solve_other(self, flag=False):
        """
        Calculate spectra by subtracting the solution of the other sideband
        when True.
        """
        self.solveother = flag
        if flag:
            asaplog.push("Expert mode: solution are obtained by subtraction of the other sideband.")

    @asaplog_post_dec
    def separate(self, outname="", overwrite=False):
        """
        Invoke sideband separation.

        outname   : a name of output scantable. The signal sideband is
                    saved as '<outname>.signalband' and, when
                    set_both(True), the image sideband as
                    '<outname>.imageband'. Defaults to 'sbsepareted.asap'.
        overwrite : overwrite existing table

        Rows for which fewer than two spectra are found are flagged.
        """
        if outname == "":
            outname = "sbsepareted.asap"
        signalname = outname + ".signalband"
        imagename = outname + ".imageband"
        # Fail before the (slow) separation if an output would be clobbered.
        outnames = [signalname]
        if self.getboth:
            outnames.append(imagename)
        for name in outnames:
            if os.path.exists(name) and not overwrite:
                raise Exception("Output file '%s' exists." % name)

        # List up valid scantables and IFNOs to convolve.
        self._setup_shift()

        signaltab = self._grid_outtable(self.tables[0].copy())
        imagetab = None
        if self.getboth:
            imagetab = signaltab.copy()

        rejrow = []
        for irow in range(signaltab.nrow()):
            currpol = signaltab.getpol(irow)
            currbeam = signaltab.getbeam(irow)
            currdir = signaltab.get_directionval(irow)
            spec_array, tabidx = self._get_specarray(polid=currpol,
                                                     beamid=currbeam,
                                                     dir=currdir)
            if spec_array is None:
                # fewer than two spectra: this row cannot be solved
                asaplog.post()
                asaplog.push("skipping row %d" % irow)
                rejrow.append(irow)
                continue
            signaltab.set_spectrum(self._solve_signal(spec_array, tabidx), irow)
            if self.getboth:
                imagetab.set_spectrum(self._solve_image(spec_array, tabidx), irow)

        # TODO: Need to remove rejrow form scantables here
        if rejrow:
            signaltab.flag_row(rejrow)
            if self.getboth:
                imagetab.flag_row(rejrow)

        self._save_table(signaltab, signalname, overwrite)
        if self.getboth:
            self._save_table(imagetab, imagename, overwrite)

    def _save_table(self, tab, name, overwrite):
        """
        Save scantable tab as name. An existing output is removed first
        when overwrite is True; otherwise an Exception is raised.
        """
        if os.path.exists(name):
            if not overwrite:
                raise Exception("Output file '%s' exists." % name)
            if os.path.isdir(name):
                shutil.rmtree(name)
            else:
                os.remove(name)
        tab.save(name)

    def _solve_signal(self, data, tabidx=None):
        """
        Solve the signal sideband spectrum.

        data   : (nspec, nchan) array of spectra (see _get_specarray)
        tabidx : index into self.signalShift/self.imageShift of the table
                 each row of data came from (default: 0 .. nspec-1)
        """
        return self._solve_sideband(data, tabidx,
                                    self.signalShift, self.imageShift)

    def _solve_image(self, data, tabidx=None):
        """
        Solve the image sideband spectrum. Arguments as in _solve_signal.
        """
        return self._solve_sideband(data, tabidx,
                                    self.imageShift, self.signalShift)

    def _solve_sideband(self, data, tabidx, ownshift, othershift):
        """
        Solve the sideband whose per-table channel shifts are ownshift,
        given the shifts othershift of the opposite sideband.

        With set_solve_other(True) the opposite sideband is solved first
        and the result is obtained by subtracting it from the data.

        Raises ValueError when data and tabidx have different lengths.
        """
        if tabidx is None or len(tabidx) == 0:
            tabidx = range(len(data))
        if len(tabidx) != len(data):
            raise ValueError("Number of spectra (%d) does not match the number of table indices (%d)" % (len(data), len(tabidx)))
        if self.solveother:
            ownshift, othershift = othershift, ownshift

        # Shift every spectrum so that the sideband to solve is aligned;
        # the other sideband is then displaced by dshift channels.
        tempshift = [-ownshift[idx] for idx in tabidx]
        dshift = [othershift[idx] - ownshift[idx] for idx in tabidx]

        shiftdata = numpy.array([self._shiftSpectrum(spec, shift)
                                 for spec, shift in zip(data, tempshift)])
        ifftdata = self._Deconvolution(shiftdata, dshift, self.rejlimit)
        solution = self._combineResult(ifftdata)
        if not self.solveother:
            return solution
        return self._subtractOtherSide(shiftdata, dshift, solution)

    @asaplog_post_dec
    def _grid_outtable(self, table):
        """
        Generate a gridded table for output (just to get rows).

        The cell size equals the direction tolerance. The spectra of the
        returned table are overwritten by the solutions in separate().
        """
        gridder = asapgrid2(table)
        gridder.setIF(self.baseif)

        cellx = str(self.dirtol[0]) + "rad"
        celly = str(self.dirtol[1]) + "rad"
        dirarr = numpy.array(table.get_directionval()).transpose()
        mapx = dirarr[0].max() - dirarr[0].min()
        mapy = dirarr[1].max() - dirarr[1].min()
        # number of cells covering the map along each axis (at least one)
        nx = max(1, int(numpy.ceil(mapx / self.dirtol[0])))
        ny = max(1, int(numpy.ceil(mapy / self.dirtol[1])))

        asaplog.push("Regrid output scantable with cell = [%s, %s]" % \
                     (cellx, celly))
        gridder.defineImage(nx=nx, ny=ny, cellx=cellx, celly=celly)
        gridder.setFunc(func='box', width=1)
        gridder.setWeight(weightType='uniform')
        gridder.grid()
        return gridder.getResult()

    @asaplog_post_dec
    def _get_specarray(self, polid=None, beamid=None, dir=None):
        """
        Collect one spectrum per selected table at POL=polid, BEAM=beamid
        and direction dir (within self.dirtol).

        Returns (spec_array, tabidx). spec_array is an (nspec, nchan) float
        array and tabidx lists, in the same order, the index in self.tables
        each spectrum came from. spec_array is None when fewer than two
        spectra are found.
        """
        ntable = len(self.tables)
        asaplog.push("Start data selection by POL=%d, BEAM=%d, direction=[%f, %f]" % (polid, beamid, dir[0], dir[1]))
        spectra = []
        tabidx = []
        for itab, tab in enumerate(self.tables):
            try:
                spec = self._select_spectrum(itab, tab, polid, beamid, dir)
            finally:
                # never leave a POL/BEAM selection on a stored table
                tab.set_selection()
            if spec is None:
                continue
            # record the index only once a spectrum is actually stored, so
            # that tabidx stays aligned with the rows of spec_array
            spectra.append(spec)
            tabidx.append(itab)

        nspec = len(spectra)
        if nspec != ntable:
            asaplog.post()
            asaplog.push("Could not find corresponding rows in some tables.")
            asaplog.push("Number of spectra selected = %d (table: %d)" % (nspec, ntable))
        if nspec < 2:
            asaplog.push("At least 2 spectra are necessary for convolution")
            asaplog.post("ERROR")
            return None, tabidx
        return numpy.array(spectra, dtype=float), tabidx

    def _select_spectrum(self, itab, tab, polid, beamid, dir):
        """
        Return the spectrum of tab at the given POL, BEAM and direction,
        time-averaged when several rows match, or None when no row matches.
        Leaves a selection on tab; the caller must reset it.
        """
        # Select rows by POLNO and BEAMNO
        try:
            tab.set_selection(pols=[polid], beams=[beamid])
        except Exception: # no selection
            self._warn_no_spectrum(itab)
            return None

        # Select rows by direction
        selrow = []
        for irow in range(tab.nrow()):
            currdir = tab.get_directionval(irow)
            if (abs(currdir[0] - dir[0]) > self.dirtol[0]) or \
               (abs(currdir[1] - dir[1]) > self.dirtol[1]):
                continue
            selrow.append(irow)
        if len(selrow) == 0:
            self._warn_no_spectrum(itab)
            return None

        seltab = tab.copy()
        seltab.set_selection(selector(rows=selrow))
        if len(selrow) > 1:
            asaplog.push("table %d - More than a spectrum selected. averaging rows..." % (itab))
            seltab = seltab.average_time(scanav=False, weight="tintsys")
        return seltab._getspectrum()

    def _warn_no_spectrum(self, itab):
        """Log a warning that table itab has no matching spectrum."""
        asaplog.post()
        asaplog.push("table %d - No spectrum ....skipping the table" % (itab))
        asaplog.post("WARN")

    def _get_intable(self, itab):
        """
        Return the itab-th input as a scantable, loading it from disk when
        the input is a filename. Raises RuntimeError if the file is missing.
        """
        intab = self.intables[itab]
        if isinstance(intab, scantable):
            return intab
        if not os.path.exists(intab):
            asaplog.post()
            raise RuntimeError("Could not find '%s'" % intab)
        return scantable(intab, average=False)

    def _use_hz(self, stab):
        """Set the spectral unit of stab to Hz and, if set, the frame to self.freqframe."""
        stab.set_unit("Hz")
        if len(self.freqframe) > 0:
            stab.set_freqframe(self.freqframe)

    def _restore_coord(self, stab, unit, frame):
        """Clear the selection of stab and restore its spectral unit and frame."""
        stab.set_selection()
        stab.set_unit(unit)
        stab.set_freqframe(frame)

    @asaplog_post_dec
    def _setup_shift(self):
        """
        Define self.tables, self.signalShift, self.imageShift, self.nchan
        and self.nshift.

        Every IF of every input table whose frequency setup matches the
        reference IF (same number of channels, first channel within the
        frequency tolerance, channel width within 1/nchan of a channel) is
        copied into self.tables. Its signal shift is the offset of its
        first channel from the reference in channels. In DSB mode the image
        shift is the negated signal shift.

        Raises RuntimeError on missing input, an invalid reference IF,
        multi-beam reference data, or fewer than two selected IFs.
        """
        if not self.intables:
            asaplog.post()
            raise RuntimeError("Input data is not defined.")

        ntab = len(self.intables)
        stab = self._get_intable(0)

        if len(stab.getbeamnos()) > 1:
            asaplog.post()
            raise RuntimeError("Multi-beam data is not supported by this module.")

        valid_ifs = stab.getifnos()
        if self.baseif < 0:
            self.baseif = valid_ifs[0]
            asaplog.post()
            asaplog.push("IFNO is not selected. Using the first IF in the first scantable. Reference IFNO = %d" % (self.baseif))

        if self.baseif not in valid_ifs:
            asaplog.post()
            errmsg = "IF%d does not exist in the first scantable" %  \
                     self.baseif
            raise RuntimeError(errmsg)

        asaplog.push("Start selecting tables and IFNOs to solve.")
        asaplog.push("Cheching frequency of the reference IF")
        unit_org = stab.get_unit()
        frame_org = stab._getcoordinfo()[1]
        try:
            self._use_hz(stab)
            stab.set_selection(ifs=[self.baseif])
            spx = stab._getabcissa()
        finally:
            self._restore_coord(stab, unit_org, frame_org)

        self.nchan = len(spx)
        if self.nchan < 2:
            asaplog.post()
            raise RuntimeError("The reference IF should have at least 2 channels.")
        basech0 = spx[0]
        baseinc = spx[1] - spx[0]
        # The stored tolerance is left as given; convert locally.
        if isinstance(self.freqtol, dict):
            vftol = abs(self.freqtol['value'])
        else:
            vftol = abs(baseinc * self.freqtol)
        inctol = abs(baseinc / float(self.nchan))
        asaplog.push("Reference frequency setup (Table = 0, IFNO = %d):  nchan = %d, chan0 = %f Hz, incr = %f Hz" % (self.baseif, self.nchan, basech0, baseinc))
        asaplog.push("Allowed frequency tolerance = %f Hz ( %f channels)" % (vftol, vftol / abs(baseinc)))
        poltype0 = stab.poltype()

        self.tables = []
        self.signalShift = []
        if self.dsbmode:
            self.imageShift = []

        for itab in range(ntab):
            asaplog.push("Table %d:" % itab)
            if itab > 0:
                stab = self._get_intable(itab)

            # Check POLTYPE should be equal to the first table.
            if stab.poltype() != poltype0:
                asaplog.post()
                raise Exception("POLTYPE should be equal to the first table.")
            # Multiple beam data may not handled properly
            if len(stab.getbeamnos()) > 1:
                asaplog.post()
                asaplog.push("table contains multiple beams. It may not be handled properly.")
                asaplog.post("WARN")

            unit_org = stab.get_unit()
            frame_org = stab._getcoordinfo()[1]
            tab_selected = False
            try:
                self._use_hz(stab)
                for ifno in stab.getifnos():
                    stab.set_selection(ifs=[ifno])
                    spx = stab._getabcissa()
                    if (len(spx) != self.nchan) or \
                       (abs(spx[0] - basech0) > vftol) or \
                       (abs(spx[1] - spx[0] - baseinc) > inctol):
                        continue
                    tab_selected = True
                    seltab = stab.copy()
                    seltab.set_unit(unit_org)
                    seltab.set_freqframe(frame_org)
                    self.tables.append(seltab)
                    self.signalShift.append((spx[0] - basech0) / baseinc)
                    if self.dsbmode:
                        self.imageShift.append(-self.signalShift[-1])
                    asaplog.push("- IF%d selected: sideband shift = %16.12e channels" % (ifno, self.signalShift[-1]))
            finally:
                # restore the caller's scantable even if an error occurred
                self._restore_coord(stab, unit_org, frame_org)
            if not tab_selected:
                asaplog.post()
                asaplog.push("No data selected in Table %d" % itab)
                asaplog.post("WARN")

        asaplog.push("Total number of IFs selected = %d" % len(self.tables))
        if len(self.tables) < 2:
            asaplog.post()
            raise RuntimeError("At least 2 IFs are necessary for convolution!")

        if not self.dsbmode and len(self.imageShift) != len(self.signalShift):
            asaplog.post()
            errmsg = "User defined channel shift of image sideband has %d elements, while selected IFNOs are %d" % (len(self.imageShift), len(self.signalShift))
            raise RuntimeError(errmsg)

        self.signalShift = numpy.array(self.signalShift, dtype=float)
        self.imageShift = numpy.array(self.imageShift, dtype=float)
        self.nshift = len(self.tables)

    @asaplog_post_dec
    def _preprocess_tables(self):
        """
        Temporary method to preprocess data: time-average every selected
        table (weight='tintsys').
        """
        for itab in range(len(self.tables)):
            self.tables[itab] = self.tables[itab].average_time(scanav=False, weight="tintsys")

    def _Deconvolution(self, data_array, shift, threshold=0.00000001):
        """
        Deconvolve every pair of spectra in Fourier space.

        data_array : (nshift, nchan) spectra, already aligned on the
                     sideband being solved
        shift      : channel displacement of the other sideband in each
                     spectrum
        threshold  : Fourier component k (k >= 1) of a pair is corrected
                     only if |sin(dX*k)| > threshold; otherwise the plain
                     average of the pair is kept for that component

        Returns a real (npair, nchan) array, npair = nshift*(nshift-1)/2,
        with one solution per pair.
        """
        data_array = numpy.asarray(data_array, dtype=float)
        nshift, nchan = data_array.shape
        npair = nshift * (nshift - 1) // 2
        ifftObs = numpy.zeros((npair, nchan), float)
        FObs = [FFT.fft(spec) for spec in data_array]
        k = numpy.arange(1, nchan)
        nreject = 0
        ipair = 0
        for i in range(nshift):
            for j in range(i + 1, nshift):
                Fobs = (FObs[i] + FObs[j]) / 2.0
                dX = (shift[j] - shift[i]) * 2.0 * math.pi / float(nchan)
                sinx = numpy.sin(dX * k)
                cosx = numpy.cos(dX * k)
                # cos == 1 would divide by zero (possible only with a
                # negative threshold)
                valid = (numpy.abs(sinx) > threshold) & (cosx != 1.0)
                nreject += int(numpy.count_nonzero(~valid))
                kv = k[valid]
                Fobs[kv] += ((FObs[i][kv] - FObs[j][kv]) / 2.0
                             / (1.0 - cosx[valid]) * sinx[valid]) * 1.0j
                # keep the real part explicitly (no implicit complex cast)
                ifftObs[ipair] = FFT.ifft(Fobs).real
                ipair += 1
        asaplog.push('Threshold=%s Reject=%d' % (threshold, nreject))
        return ifftObs

    def _combineResult(self, ifftObs):
        """Return the mean of the per-pair solutions (rows of ifftObs)."""
        # numpy.mean does not modify its input (the old running sum
        # aliased and overwrote ifftObs[0]).
        return numpy.mean(numpy.asarray(ifftObs, dtype=float), axis=0)

    def _subtractOtherSide(self, data_array, shift, Data):
        """
        Subtract the solved spectrum Data from each spectrum of data_array,
        shift each residual back by -shift[i] channels and return their mean.
        """
        total = numpy.zeros(len(Data), float)
        for spec, sh in zip(data_array, shift):
            total += self._shiftSpectrum(spec - Data, -sh)
        return total / float(len(data_array))

    def _shiftSpectrum(self, data, Shift):
        """
        Cyclically shift a spectrum by a (possibly fractional) number of
        channels with linear interpolation.

        With s = floor(Shift) and f = Shift - s, input channel i is split
        between output channels (i + s) and (i + s + 1), modulo the number
        of channels, with weights (1 - f) and f. The integer and fractional
        parts come from the same floor(), so a Shift just below an integer
        cannot produce an out-of-range index or an off-by-one placement.
        """
        data = numpy.asarray(data, dtype=float)
        nchan = len(data)
        ishift = int(math.floor(Shift))
        frac = Shift - ishift
        c1 = (numpy.arange(nchan) + ishift) % nchan
        c2 = (c1 + 1) % nchan
        out = numpy.bincount(c1, weights=data * (1.0 - frac), minlength=nchan)
        out += numpy.bincount(c2, weights=data * frac, minlength=nchan)
        return out
Note: See TracBrowser for help on using the repository browser.