source: trunk/python/asapfitter.py @ 515

Last change on this file since 515 was 515, checked in by mar637, 19 years ago
  • major rework on plotting.
  • added component selection and plotting
  • added wrapper function to set parameters
  • addedd formatting of parameter print out
  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.0 KB
Line 
1import _asap
2from asap import rcParams
3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
27        self.fitfuncs = None
28        self.fitted = False
29        self.data = None
30        self.components = 0
31        self._fittedrow = 0
32        self._p = None
33        self._vb = True
34        self._selection = None
35
36    def set_data(self, xdat, ydat, mask=None):
37        """
38        Set the absissa and ordinate for the fit. Also set the mask
39        indicationg valid points.
40        This can be used for data vectors retrieved from a scantable.
41        For scantable fitting use 'fitter.set_scan(scan, mask)'.
42        Parameters:
43            xdat:    the abcissa values
44            ydat:    the ordinate values
45            mask:    an optional mask
46       
47        """
48        self.fitted = False
49        self.x = xdat
50        self.y = ydat
51        if mask == None:
52            from numarray import ones
53            self.mask = ones(len(xdat))
54        else:
55            self.mask = mask
56        return
57
58    def set_scan(self, thescan=None, mask=None):
59        """
60        Set the 'data' (a scantable) of the fitter.
61        Parameters:
62            thescan:     a scantable
63            mask:        a msk retireved from the scantable
64        """
65        if not thescan:
66            print "Please give a correct scan"
67        self.fitted = False
68        self.data = thescan
69        if mask is None:
70            from numarray import ones
71            self.mask = ones(self.data.nchan())
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given
81            gauss:   fit the number of gaussian specified
82        Example:
83            fitter.set_function(gauss=2) # will fit two gaussians
84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
85        """
86        #default poly order 0       
87        n=0
88        if kwargs.has_key('poly'):
89            self.fitfunc = 'poly'
90            n = kwargs.get('poly')
91            self.components = [n]
92        elif kwargs.has_key('gauss'):
93            n = kwargs.get('gauss')
94            self.fitfunc = 'gauss'
95            self.fitfuncs = [ 'gauss' for i in range(n) ]
96            self.components = [ 3 for i in range(n) ]
97        else:
98            print "Invalid function type."
99            return
100        self.fitter.setexpression(self.fitfunc,n)
101        return
102           
103    def fit(self, row=0):
104        """
105        Execute the actual fitting process. All the state has to be set.
106        Parameters:
107            none
108        Example:
109            s = scantable('myscan.asap')
110            s.set_cursor(thepol=1)        # select second pol
111            f = fitter()
112            f.set_scan(s)
113            f.set_function(poly=0)
114            f.fit(row=0)                  # fit first row
115        """
116        if ((self.x is None or self.y is None) and self.data is None) \
117               or self.fitfunc is None:
118            print "Fitter not yet initialised. Please set data & fit function"
119            return
120        else:
121            if self.data is not None:
122                self.x = self.data._getabcissa(row)
123                self.y = self.data._getspectrum(row)
124                print "Fitting:"
125                vb = self.data._vb
126                self.data._vb = True
127                self.selection = self.data.get_cursor()
128                self.data._vb = vb
129        self.fitter.setdata(self.x, self.y, self.mask)
130        if self.fitfunc == 'gauss':
131            ps = self.fitter.getparameters()
132            if len(ps) == 0:
133                self.fitter.estimate()
134        self.fitter.fit()
135        self._fittedrow = row
136        self.fitted = True
137        return
138
139    def store_fit(self):
140        if self.fitted and self.data is not None:
141            pars = list(self.fitter.getparameters())
142            fixed = list(self.fitter.getfixedparameters())
143            self.data._addfit(self._fittedrow, pars, fixed,
144                              self.fitfuncs, self.components)
145
146    def set_parameters(self, params, fixed=None, component=None):
147        if self.fitfunc is None:
148            print "Please specify a fitting function first."
149            return
150        if self.fitfunc == "gauss" and component is not None:
151            if not self.fitted:
152                from numarray import zeros
153                pars = list(zeros(len(self.components)*3))
154                fxd = list(zeros(len(pars)))
155            else:
156                pars = list(self.fitter.getparameters())             
157                fxd = list(self.fitter.getfixedparameters())
158            i = 3*component
159            pars[i:i+3] = params
160            fxd[i:i+3] = fixed
161            params = pars
162            fixed = fxd         
163        self.fitter.setparameters(params)
164        if fixed is not None:
165            self.fitter.setfixedparameters(fixed)
166        return
167
168    def set_gauss_parameters(self, peak, centre, fhwm,
169                             peakfixed=False, centerfixed=False,
170                             fhwmfixed=False,
171                             component=0):
172        """
173        Set the Parameters of a 'Gaussian' component, set with set_function.
174        Parameters:
175            component:           The number of the component (Default is the
176                                 first component.
177            peak, centre, fhwm:  The gaussian parameters
178            peakfixed,
179            centerfixed,
180            fhwmfixed:           Optional parameters to indicate if
181                                 the paramters should be held fixed during
182                                 the fitting process. The default is to keep
183                                 all parameters flexible.
184        """
185        if self.fitfunc != "gauss":
186            print "Function only operates on Gaussian components."
187            return
188        if 0 <= component < len(self.components):
189            self.set_parameters([peak, centre, fhwm],
190                                [peakfixed, centerfixed, fhwmfixed],
191                                component)
192        else:
193            print "Please select a valid  component."
194            return
195       
196    def get_parameters(self, component=None):
197        """
198        Return the fit paramters.
199       
200        """
201        if not self.fitted:
202            print "Not yet fitted."
203        pars = list(self.fitter.getparameters())
204        fixed = list(self.fitter.getfixedparameters())
205        if component is not None:           
206            if self.fitfunc == "gauss":
207                i = 3*component
208                cpars = pars[i:i+3]
209                cfixed = fixed[i:i+3]
210            else:
211                cpars = pars
212                cfixed = fixed               
213        else:
214            cpars = pars
215            cfixed = fixed
216        fpars = self._format_pars(cpars, cfixed)
217        if self._vb:
218            print fpars
219        return cpars, cfixed, fpars
220   
221    def _format_pars(self, pars, fixed):
222        out = ''
223        if self.fitfunc == 'poly':
224            c = 0
225            for i in range(len(pars)):
226                fix = ""
227                if fixed[i]: fix = "(fixed)"
228                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
229                c+=1
230            out = out[:-1]  # remove trailing ','
231        elif self.fitfunc == 'gauss':
232            i = 0
233            c = 0
234            aunit = ''
235            ounit = ''
236            if self.data:
237                aunit = self.data.get_unit()
238                ounit = self.data.get_fluxunit()
239            while i < len(pars):
240                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
241                c+=1
242                i+=3
243        return out
244       
245    def get_estimate(self):
246        """
247        Return the parameter estimates (for non-linear functions).
248        """
249        pars = self.fitter.getestimate()
250        if self._vb:
251            print self._format_pars(pars)
252        return pars
253       
254
255    def get_residual(self):
256        """
257        Return the residual of the fit.
258        """
259        if not self.fitted:
260            print "Not yet fitted."
261        return self.fitter.getresidual()
262
263    def get_chi2(self):
264        """
265        Return chi^2.
266        """
267        if not self.fitted:
268            print "Not yet fitted."
269        ch2 = self.fitter.getchi2()
270        if self._vb:
271            print 'Chi^2 = %3.3f' % (ch2)
272        return ch2
273
274    def get_fit(self):
275        """
276        Return the fitted ordinate values.
277        """
278        if not self.fitted:
279            print "Not yet fitted."
280        return self.fitter.getfit()
281
282    def commit(self):
283        """
284        Return a new scan where the fits have been commited.
285        """
286        if not self.fitted:
287            print "Not yet fitted."
288        if self.data is not scantable:
289            print "Only works with scantables"
290            return
291        scan = self.data.copy()
292        scan._setspectrum(self.fitter.getresidual())
293
294    def plot(self, residual=False, components=None, plotparms=False,
295             plotrange=None):
296        """
297        Plot the last fit.
298        Parameters:
299            residual:    an optional parameter indicating if the residual
300                         should be plotted (default 'False')
301        """
302        if not self.fitted:
303            return
304        if not self._p:
305            from asap.asaplot import ASAPlot
306            self._p = ASAPlot()
307        if self._p.is_dead:
308            from asap.asaplot import ASAPlot
309            self._p = ASAPlot()
310        self._p.clear()
311        self._p.set_panels()
312        self._p.palette(1)
313        tlab = 'Spectrum'
314        xlab = 'Abcissa'       
315        m = ()
316        if self.data:
317            tlab = self.data._getsourcename(self._fittedrow)
318            xlab = self.data._getabcissalabel(self._fittedrow)
319            m = self.data._getmask(self._fittedrow)
320            ylab = r'Flux'
321
322        colours = ["grey60","grey80","red","orange","purple","yellow","magenta", "cyan"]
323        self._p.palette(1,colours)
324        self._p.set_line(label='Spectrum')
325        self._p.plot(self.x, self.y, m)
326        if residual:
327            self._p.palette(2)
328            self._p.set_line(label='Residual')
329            self._p.plot(self.x, self.get_residual(), m)
330        self._p.palette(3)
331        if components is not None:
332            cs = components
333            if isinstance(components,int): cs = [components]
334            self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
335            n = len(self.components)
336            self._p.palette(4)
337            for c in cs:
338                if 0 <= c < n:
339                    lab = self.fitfuncs[c]+str(c)
340                    self._p.set_line(label=lab)
341                    self._p.plot(self.x, self.fitter.evaluate(c), m)
342                elif c == -1:
343                    self._p.palette(3)
344                    self._p.set_line(label="Total Fit")
345                    self._p.plot(self.x, self.get_fit(), m)                   
346        else:
347            self._p.palette(3)
348            self._p.set_line(label='Fit')
349            self._p.plot(self.x, self.get_fit(), m)
350        self._p.set_axes('xlabel',xlab)
351        self._p.set_axes('ylabel',ylab)
352        self._p.set_axes('title',tlab)
353        self._p.release()
354
355    def auto_fit(self, insitu=None):
356        """
357        Return a scan where the function is applied to all rows for
358        all Beams/IFs/Pols.
359       
360        """
361        from asap import scantable
362        if not isinstance(self.data, scantable) :
363            print "Only works with scantables"
364            return
365        if insitu is None: insitu = rcParams['insitu']
366        if not insitu:
367            scan = self.data.copy()
368        else:
369            scan = self.data
370        vb = scan._vb
371        scan._vb = False
372        sel = scan.get_cursor()
373        rows = range(scan.nrow())
374        for i in range(scan.nbeam()):
375            scan.setbeam(i)
376            for j in range(scan.nif()):
377                scan.setif(j)
378                for k in range(scan.npol()):
379                    scan.setpol(k)
380                    if self._vb:
381                        print "Fitting:"
382                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
383                    for iRow in rows:
384                        self.x = scan._getabcissa(iRow)
385                        self.y = scan._getspectrum(iRow)
386                        self.data = None
387                        self.fit()                   
388                        x = self.get_parameters()
389                        scan._setspectrum(self.fitter.getresidual(),iRow)
390        scan.set_cursor(sel[0],sel[1],sel[2])
391        scan._vb = vb
392        if not insitu:
393            return scan
Note: See TracBrowser for help on using the repository browser.