source: branches/alma/python/asapfitter.py@ 1612

Last change on this file since 1612 was 1612, checked in by Takeshi Nakazato, 15 years ago

New Development: No

JIRA Issue: Yes CAS-729, CAS-1147

Ready to Release: Yes

Interface Changes: No

What Interface Changed: Please list interface changes

Test Programs: List test programs

Put in Release Notes: Yes

Module(s): Module Names change impacts.

Description: Describe your changes here...

I have changed almost all log messages so that they are written to casapy.log
rather than to the terminal window. After this change, asap depends on casapy
and can no longer run standalone, because asap has to import the taskinit module
to access casalogger.


  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.4 KB
RevLine 
[113]1import _asap
[259]2from asap import rcParams
[723]3from asap import print_log
[1295]4from asap import _n_bools
[1603]5from asap import mask_and
[1612]6from taskinit import *
[113]7
class fitter:
    """
    The fitting class for ASAP.

    Fits either polynomials or a sum of Gaussians to a spectrum given as
    plain abscissa/ordinate vectors (set_data) or to the rows of a
    scantable (set_scan).
    """

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None
        self.fitfuncs = None
        self.fitted = False
        self.data = None
        self.components = 0
        self._fittedrow = 0
        self._p = None
        self._selection = None
        self.uselinear = False

    def _report(self, msg, exctype):
        """
        Report a usage problem.

        When rcParams['verbose'] is set, msg is posted as a warning to the
        CASA logger and this returns normally (callers then return
        themselves). Otherwise exctype(msg) is raised.
        """
        if rcParams['verbose']:
            casalog.post(msg, 'WARN')
        else:
            raise exctype(msg)

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abscissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abscissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points are valid)

        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None' rather than '== None': an array mask compares
        # element-wise with '==' and has no single truth value.
        if mask is None:
            self.mask = _n_bools(len(xdat), True)
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels are valid)
        """
        if not thescan:
            self._report("Please give a correct scan", TypeError)
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            self.mask = _n_bools(self.data.nchan(), True)
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given with nonlinear
                     least squares fit
            lpoly:   use a polynomial of the order given with linear
                     least squares fit
            gauss:   fit the number of gaussians specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
            fitter.set_function(lpoly=3) # will fit a 3rd order polynomial via linear method
        """
        # default poly order 0
        n = 0
        if 'poly' in kwargs:
            self.fitfunc = 'poly'
            n = kwargs.get('poly')
            self.uselinear = False
        elif 'lpoly' in kwargs:
            self.fitfunc = 'poly'
            n = kwargs.get('lpoly')
            self.uselinear = True
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            # each gaussian has three parameters: peak, centre, fwhm
            self.components = [3] * n
            self.uselinear = False
        else:
            self._report("Invalid function type.", TypeError)
            return
        if self.fitfunc == 'poly':
            # A polynomial of order n is one function with n+1 coefficients.
            # Resetting fitfuncs also drops a stale gaussian list left by a
            # previous set_function(gauss=...) call.
            self.fitfuncs = ['poly']
            self.components = [n + 1]
        self.fitter.setexpression(self.fitfunc, n)
        self.fitted = False
        return

    def fit(self, row=0, estimate=False):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:         specify the row in the scantable
            estimate:    auto-compute an initial parameter set (default False)
                         This can be used to compute estimates even if fit was
                         called before.
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            self._report("Fitter not yet initialised. Please set data & fit function",
                         RuntimeError)
            return

        fitmask = self.mask
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            # Combine the user mask with this row's flags locally. Storing
            # the result back into self.mask would accumulate the flags of
            # every previously fitted row.
            fitmask = mask_and(self.mask, self.data._getmask(row))
            from asap import asaplog
            asaplog.push("Fitting:")
            out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(row),
                                                                  self.data.getbeam(row),
                                                                  self.data.getif(row),
                                                                  self.data.getpol(row),
                                                                  self.data.getcycle(row))
            asaplog.push(out, False)
        # Invalidate any previous result until this fit has run.
        self.fitted = False
        self.fitter.setdata(self.x, self.y, fitmask)
        if self.fitfunc == 'gauss':
            ps = self.fitter.getparameters()
            if len(ps) == 0 or estimate:
                self.fitter.estimate()
        try:
            fxdpar = list(self.fitter.getfixedparameters())
            if len(fxdpar) and fxdpar.count(0) == 0:
                raise RuntimeError("No point fitting, if all parameters are fixed.")
            if self.uselinear:
                converged = self.fitter.lfit()
            else:
                converged = self.fitter.fit()
            if not converged:
                raise RuntimeError("Fit didn't converge.")
        except RuntimeError:
            if rcParams['verbose']:
                # sys.exc_info keeps this valid on both old and new Pythons
                import sys
                casalog.post(str(sys.exc_info()[1]), 'WARN')
            else:
                raise
        self._fittedrow = row
        self.fitted = True
        print_log()
        return

    def store_fit(self, filename=None):
        """
        Save the fit parameters.
        Nothing is stored unless a fit of scantable data has been made.
        Parameters:
            filename:    if specified save as an ASCII file, if None (default)
                         store it in the scantable
        Raises:
            IOError if filename already exists.
        """
        if not (self.fitted and self.data is not None):
            return
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        from asap.asapfit import asapfit
        fit = asapfit()
        fit.setparameters(pars)
        fit.setfixedparameters(fixed)
        fit.setfunctions(self.fitfuncs)
        fit.setcomponents(self.components)
        fit.setframeinfo(self.data._getcoordinfo())
        if filename is not None:
            import os
            filename = os.path.expandvars(os.path.expanduser(filename))
            if os.path.exists(filename):
                raise IOError("File '%s' exists." % filename)
            fit.save(filename)
        else:
            self.data._addfit(fit, self._fittedrow)

    def set_parameters(self, *args, **kwargs):
        """
        Set the parameters to be fitted.
        Parameters:
            params:      a vector of parameters
            fixed:       a vector of which parameters are to be held fixed
                         (default is none)
            component:   in case of multiple gaussians, the index of the
                         component
        The values may be given as keywords; as a single dict with the keys
        'params'/'fixed'/'component' (the component index may also follow
        the dict as a second positional argument); or positionally as
        (params, fixed, component).
        """
        params = None
        fixed = None
        component = None
        if args and isinstance(args[0], dict):
            # dict form: the dict takes the place of keyword arguments
            kwargs = args[0]
        elif args:
            params = args[0]
            if len(args) > 1 and not isinstance(args[1], int):
                fixed = args[1]
            if len(args) > 2:
                component = args[2]
        params = kwargs.get('params', params)
        fixed = kwargs.get('fixed', fixed)
        component = kwargs.get('component', component)
        if len(args) == 2 and isinstance(args[1], int):
            component = args[1]

        if self.fitfunc is None:
            self._report("Please specify a fitting function first.", RuntimeError)
            return
        if self.fitfunc == "gauss" and component is not None:
            if not self.fitted and sum(self.fitter.getparameters()) == 0:
                pars = _n_bools(len(self.components)*3, False)
                fxd = _n_bools(len(pars), False)
            else:
                pars = list(self.fitter.getparameters())
                fxd = list(self.fitter.getfixedparameters())
            i = 3*component
            pars[i:i+3] = params
            # Only overwrite this component's fixed flags when given;
            # slice-assigning None would raise a TypeError.
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        print_log()
        return

    def set_gauss_parameters(self, peak, centre, fwhm,
                             peakfixed=0, centrefixed=0,
                             fwhmfixed=0,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fwhm:  The gaussian parameters
            peakfixed,
            centrefixed,
            fwhmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            self._report("Function only operates on Gaussian components.", ValueError)
            return
        if not 0 <= component < len(self.components):
            self._report("Please select a valid component.", ValueError)
            return
        d = {'params': [peak, centre, fwhm],
             'fixed': [peakfixed, centrefixed, fwhmfixed]}
        self.set_parameters(d, component)

    def get_area(self, component=None):
        """
        Return the area under the fitted gaussian component.
        Parameters:
              component:   the gaussian component selection,
                           default (None) is the sum of all components
        Note:
              This will only work for gaussian fits. None is returned
              before a fit or for other fit functions.
        """
        if not self.fitted:
            return
        if self.fitfunc != "gauss":
            return None
        from math import log, pi, sqrt
        # Gaussian area = peak * FWHM * sqrt(pi / (4 ln 2)); 4 ln 2 == ln 16
        fac = sqrt(pi/log(16.0))
        pars = list(self.fitter.getparameters())
        areas = [fac * pars[3*i] * pars[3*i+2]
                 for i in range(len(self.components))]
        if component is not None:
            return areas[component]
        return sum(areas)

    def get_errors(self, component=None):
        """
        Return the errors in the parameters.
        Parameters:
            component:    get the errors for the specified component
                          only, default is all components
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        errs = list(self.fitter.geterrors())
        cerrs = errs
        if component is not None and self.fitfunc == "gauss":
            i = 3*component
            if i < len(errs):
                cerrs = errs[i:i+3]
        return cerrs

    def get_parameters(self, component=None, errors=False):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
             errors:       if True, include the parameter errors in the
                           'formatted' string (shown for polynomial fits)
        Returns:
             a dict with keys 'params', 'fixed', 'formatted' and 'errors'.
             When rcParams['verbose'] is set the formatted string is also
             posted to the CASA logger.
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        errs = list(self.fitter.geterrors())
        area = []
        cpars, cfixed, cerrs = pars, fixed, errs
        if self.fitfunc == "gauss":
            if component is not None:
                i = 3*component
                cpars = pars[i:i+3]
                cfixed = fixed[i:i+3]
                cerrs = errs[i:i+3]
                # one area entry per parameter, because _format_pars indexes
                # the area list with the parameter index
                area = [self.get_area(component)] * 3
            else:
                for c in range(len(self.components)):
                    area += [self.get_area(c)] * 3
        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
        if rcParams['verbose']:
            casalog.post(fpars)
        return {'params': cpars, 'fixed': cfixed, 'formatted': fpars,
                'errors': cerrs}

    def _format_pars(self, pars, fixed, errors, area=None):
        """
        Format fit parameters as a human-readable string.
        Parameters:
            pars:    parameter values
            fixed:   fixed flags per parameter (may be empty)
            errors:  per-parameter errors, or a false value to omit them
                     (errors are only shown for polynomial fits)
            area:    optional per-parameter list of gaussian areas
                     (three identical entries per component)
        """
        out = ''
        if self.fitfunc == 'poly':
            entries = []
            for i in range(len(pars)):
                fix = ""
                if len(fixed) and fixed[i]:
                    fix = "(fixed)"
                if errors:
                    entries.append(' p%d%s= %3.6f (%1.6f)' % (i, fix, pars[i], errors[i]))
                else:
                    entries.append(' p%d%s= %3.6f' % (i, fix, pars[i]))
            out = ','.join(entries)
        elif self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            c = 0
            for i in range(0, len(pars), 3):
                out += ' %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' \
                       % (c, pars[i], ounit, pars[i+1], aunit, pars[i+2], aunit)
                if area:
                    out += ' area = %3.3f %s %s\n' % (area[i], ounit, aunit)
                c += 1
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        fixed = self.fitter.getfixedparameters()
        if rcParams['verbose']:
            casalog.post(self._format_pars(pars, fixed, None))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        ch2 = self.fitter.getchi2()
        if rcParams['verbose']:
            casalog.post('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fit has been committed (subtracted)
        from the fitted row.
        """
        if not self.fitted:
            self._report("Not yet fitted.", RuntimeError)
            return
        from asap import scantable
        if not isinstance(self.data, scantable):
            self._report("Not a scantable", TypeError)
            return
        scan = self.data.copy()
        # write the residual into the row that was actually fitted
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        print_log()
        return scan

    def plot(self, residual=False, components=None, plotparms=False, filename=None):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
            filename:    output file used when the non-GUI plotter is active
        """
        if not self.fitted:
            return
        if not self._p or self._p.is_dead:
            if rcParams['plotter.gui']:
                from asap.asaplotgui import asaplotgui as asaplot
            else:
                from asap.asaplot import asaplot
            self._p = asaplot()
        self._p.hold()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        ylab = 'Ordinate'
        from matplotlib.numerix import ma, logical_not, logical_and, array
        m = self.mask
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = logical_and(self.mask,
                            array(self.data._getmask(self._fittedrow),
                                  copy=False))
            ylab = self.data._get_ordinate_label()
        # hide channels outside the mask in every fitted curve
        invalid = logical_not(m)

        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
        nomask = all(m)
        label0 = 'Masked Region'
        label1 = 'Spectrum'
        if nomask:
            label0 = label1
        else:
            y = ma.masked_array(self.y, mask=m)
            self._p.palette(1, colours)
            self._p.set_line(label=label1)
            self._p.plot(self.x, y)
        self._p.palette(0, colours)
        self._p.set_line(label=label0)
        y = ma.masked_array(self.y, mask=invalid)
        self._p.plot(self.x, y)
        if residual:
            self._p.palette(7)
            self._p.set_line(label='Residual')
            y = ma.masked_array(self.get_residual(), mask=invalid)
            self._p.plot(self.x, y)
        self._p.palette(2)
        # shown regardless of the component selection, as documented
        if plotparms:
            self._p.text(0.15, 0.15, str(self.get_parameters()['formatted']), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c]+str(c)
                    self._p.set_line(label=lab)
                    y = ma.masked_array(self.fitter.evaluate(c), mask=invalid)
                    self._p.plot(self.x, y)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    y = ma.masked_array(self.fitter.getfit(), mask=invalid)
                    self._p.plot(self.x, y)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            y = ma.masked_array(self.fitter.getfit(), mask=invalid)
            self._p.plot(self.x, y)
        xlim = [min(self.x), max(self.x)]
        self._p.axes.set_xlim(xlim)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()
        if not rcParams['plotter.gui']:
            self._p.save(filename)
        print_log()

    def auto_fit(self, insitu=None, plot=False):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols.
        Parameters:
            insitu:    if False, work on a copy of the scantable (default
                       taken from rcParams['insitu'])
            plot:      if True, plot each fit and ask whether to accept it
        The fitted parameters of every row (None for rejected fits) are
        kept in self.blpars.
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            self._report("Data is not a scantable", TypeError)
            return
        if insitu is None:
            insitu = rcParams['insitu']
        if not insitu:
            scan = self.data.copy()
        else:
            scan = self.data
        rows = xrange(scan.nrow())
        # Save parameters of baseline fits as a class attribute.
        # NOTICE: This does not reflect changes in scantable!
        self.blpars = []
        from asap import asaplog
        asaplog.push("Fitting:")
        # fit() reads from self.data when it is set, so detach the data and
        # feed each row through self.x/self.y/self.mask instead. The
        # originals are restored afterwards (also on error). The user mask
        # is combined with each row's flags on its own, not accumulated.
        data = self.data
        basemask = self.mask
        try:
            self.data = None
            for r in rows:
                out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),
                                                                       scan.getbeam(r),
                                                                       scan.getif(r),
                                                                       scan.getpol(r),
                                                                       scan.getcycle(r))
                asaplog.push(out, False)
                self.x = scan._getabcissa(r)
                self.y = scan._getspectrum(r)
                self.mask = mask_and(basemask, scan._getmask(r))
                # pass the row so _fittedrow matches the data being fitted
                self.fit(r)
                fpar = self.get_parameters()
                if plot:
                    self.plot(residual=True)
                    answer = raw_input("Accept fit ([y]/n): ")
                    if answer.upper() == 'N':
                        self.blpars.append(None)
                        continue
                scan._setspectrum(self.fitter.getresidual(), r)
                self.blpars.append(fpar)
        finally:
            self.data = data
            self.mask = basemask
        if plot and self._p is not None:
            self._p.unmap()
            self._p = None
        print_log()
        return scan
[794]657
Note: See TracBrowser for help on using the repository browser.