source: trunk/python/asapfitter.py@ 1088

Last change on this file since 1088 was 1088, checked in by mar637, 19 years ago

use MA instead of spectrum and mask for plotting. THIS isn't tested yet. printing out errors for poly coeffs. need to do gauss

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 20.1 KB
RevLine 
[113]1import _asap
[259]2from asap import rcParams
[723]3from asap import print_log
[113]4
class fitter:
    """
    The fitting class for ASAP.

    Typical use: give it data (set_data or set_scan), choose a model
    (set_function), optionally seed parameters (set_parameters /
    set_gauss_parameters), call fit(), then inspect the result
    (get_parameters, get_residual, get_chi2, plot, ...).

    Errors follow the ASAP convention: when rcParams['verbose'] is set the
    message is printed and the method returns None, otherwise an exception
    is raised.
    """

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        # the underlying _asap fitter that does the numerical work
        self.fitter = _asap.fitter()
        # abcissa, ordinate and mask of the spectrum being fitted
        self.x = None
        self.y = None
        self.mask = None
        # 'poly' or 'gauss', as chosen in set_function
        self.fitfunc = None
        # per-component function names, e.g. ['gauss', 'gauss']
        self.fitfuncs = None
        self.fitted = False
        # the scantable given to set_scan (None when set_data is used)
        self.data = None
        # per-component info: [order] for poly, [3, 3, ...] for gauss
        self.components = 0
        # row of self.data the last fit was done on
        self._fittedrow = 0
        # plotter, created lazily by plot()
        self._p = None
        self._selection = None

    def _error(self, msg, exctype=RuntimeError):
        """
        Report an error: print msg and return None when
        rcParams['verbose'] is set, otherwise raise exctype(msg).
        Callers use 'return self._error(...)' so that verbose mode
        aborts the calling method with a None result.
        """
        if rcParams['verbose']:
            print(msg)
            return None
        raise exctype(msg)

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abcissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abcissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points valid)

        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None' rather than '== None': the mask may be an array, for
        # which '==' compares element-wise
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels valid)
        """
        if not thescan:
            return self._error("Please give a correct scan", TypeError)
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        #default poly order 0
        n = 0
        if 'poly' in kwargs:
            self.fitfunc = 'poly'
            n = kwargs.get('poly')
            # a single polynomial component; store_fit and plot index
            # fitfuncs per component, so it must not stay None
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            # each gaussian has three parameters: peak, centre, FWHM
            self.components = [3] * n
        else:
            return self._error("Invalid function type.", TypeError)
        # any previous fit belongs to the old model
        self.fitted = False
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0, estimate=False):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:        specify the row in the scantable
            estimate:   auto-compute an initial parameter set (default False)
                        This can be used to compute estimates even if fit was
                        called before.
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            msg = "Fitter not yet initialised. Please set data & fit function"
            return self._error(msg, RuntimeError)
        if self.data is not None:
            # a scantable is set: take abcissa/spectrum from the given row
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            from asap import asaplog
            asaplog.push("Fitting:")
            i = row
            out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),self.data.getbeam(i),self.data.getif(i),self.data.getpol(i), self.data.getcycle(i))
            asaplog.push(out, False)
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            # gaussians need starting values; estimate them when none are
            # set or when explicitly asked to
            ps = self.fitter.getparameters()
            if len(ps) == 0 or estimate:
                self.fitter.estimate()
        try:
            converged = self.fitter.fit()
            if not converged:
                raise RuntimeError("Fit didn't converge.")
        except RuntimeError:
            if not rcParams['verbose']:
                raise
            import sys
            print(sys.exc_info()[1])
        # NOTE: in verbose mode a failed/non-converged fit is still marked
        # as fitted so its parameters can be inspected
        self._fittedrow = row
        self.fitted = True
        print_log()
        return

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Does nothing unless a fit has been done on a scantable.
        """
        if self.fitted and self.data is not None:
            pars = list(self.fitter.getparameters())
            fixed = list(self.fitter.getfixedparameters())
            from asap.asapfit import asapfit
            fit = asapfit()
            fit.setparameters(pars)
            fit.setfixedparameters(fixed)
            fit.setfunctions(self.fitfuncs)
            fit.setcomponents(self.components)
            fit.setframeinfo(self.data._getcoordinfo())
            self.data._addfit(fit, self._fittedrow)

    def set_parameters(self, *args, **kwargs):
        """
        Set the parameters to be fitted.
        Accepted call forms:
            set_parameters(params, fixed=None, component=None)
            set_parameters(params, component)
            set_parameters({'params': ..., 'fixed': ...}, component)
            set_parameters(params=..., fixed=..., component=...)
        Parameters:
            params:    a vector of parameters
            fixed:     a vector of which parameters are to be held fixed
                       (default is none)
            component: in case of multiple gaussians, the index of the
                       component
        """
        params = kwargs.get('params')
        fixed = kwargs.get('fixed')
        component = kwargs.get('component')
        if args:
            rest = list(args[1:])
            if isinstance(args[0], dict):
                # e.g. the dictionary returned by get_parameters()
                params = args[0].get('params', params)
                fixed = args[0].get('fixed', fixed)
            else:
                params = args[0]
                # a lone integer after params is the component index,
                # anything else is the 'fixed' vector
                if rest and not isinstance(rest[0], int):
                    fixed = rest.pop(0)
            if rest and isinstance(rest[0], int):
                component = rest[0]
        if self.fitfunc is None:
            return self._error("Please specify a fitting function first.",
                               RuntimeError)
        if self.fitfunc == "gauss" and component is not None:
            # splice this component's three values into the full vectors
            if not self.fitted and sum(self.fitter.getparameters()) == 0:
                # no parameters set yet: start from all-zero, all-free
                npars = 3*len(self.components)
                pars = [0.0] * npars
                fxd = [0] * npars
            else:
                pars = list(self.fitter.getparameters())
                fxd = list(self.fitter.getfixedparameters())
            i = 3*component
            if params is not None:
                pars[i:i+3] = params
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        if params is not None:
            self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        print_log()
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=0, centerfixed=0,
                             fhwmfixed=0,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            return self._error("Function only operates on Gaussian components.",
                               ValueError)
        if not 0 <= component < len(self.components):
            return self._error("Please select a valid component.", ValueError)
        d = {'params': [peak, centre, fhwm],
             'fixed': [peakfixed, centerfixed, fhwmfixed]}
        self.set_parameters(d, component)

    def get_area(self, component=None):
        """
        Return the area under the fitted gaussian component.
        Parameters:
              component:   the gaussian component selection,
                           default (None) is the sum of all components
        Note:
              This will only work for gaussian fits; otherwise (or before
              a fit) None is returned.
        """
        if not self.fitted:
            return None
        if self.fitfunc != "gauss":
            return None
        pars = list(self.fitter.getparameters())
        from math import log, pi, sqrt
        # area of a gaussian = peak * FWHM * sqrt(pi / (4 ln 2))
        fac = sqrt(pi/log(16.0))
        areas = [fac * pars[3*i] * pars[3*i+2]
                 for i in range(len(self.components))]
        if component is not None:
            return areas[component]
        return sum(areas)

    def get_errors(self, component=None):
        """
        Return the errors in the parameters.
        Parameters:
            component:    get the errors for the specified component
                          only, default is all components
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        errs = list(self.fitter.geterrors())
        if component is not None and self.fitfunc == "gauss":
            i = 3*component
            if i < len(errs):
                return errs[i:i+3]
        return errs

    def get_parameters(self, component=None, errors=False):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
             errors:       include the parameter errors in the formatted
                           output (default False)
        Returns a dictionary with the keys 'params', 'fixed', 'formatted'
        and 'errors'.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        errs = list(self.fitter.geterrors())
        area = []
        first = 0
        if component is not None and self.fitfunc == "gauss":
            i = 3*component
            cpars = pars[i:i+3]
            cfixed = fixed[i:i+3]
            cerrs = errs[i:i+3]
            # one area entry per parameter, as _format_pars indexes it
            area = [self.get_area(component)] * 3
            # label the component with its real index
            first = component
        else:
            cpars = pars
            cfixed = fixed
            cerrs = errs
            if component is None and self.fitfunc == "gauss":
                for c in range(len(self.components)):
                    area += [self.get_area(c)] * 3
        if errors:
            ferrs = cerrs
        else:
            ferrs = None
        fpars = self._format_pars(cpars, cfixed, ferrs, area, first)
        if rcParams['verbose']:
            print(fpars)
        return {'params': cpars, 'fixed': cfixed, 'formatted': fpars,
                'errors': cerrs}

    def _format_pars(self, pars, fixed, errors=None, area=None, first=0):
        """
        Return a human readable string of the parameters.
        Parameters:
            pars:    flat parameter vector
            fixed:   matching 'held fixed' flags (used for 'poly')
            errors:  optional matching errors (used for 'poly' only)
            area:    optional areas, one entry per gaussian parameter
            first:   index used to label the first gaussian component
        """
        out = ''
        if self.fitfunc == 'poly':
            terms = []
            for i in range(len(pars)):
                fix = ""
                if fixed[i]: fix = "(fixed)"
                if errors:
                    terms.append(' p%d%s= %3.6f (%1.6f)' % (i, fix, pars[i], errors[i]))
                else:
                    terms.append(' p%d%s= %3.6f' % (i, fix, pars[i]))
            out = ','.join(terms)
        elif self.fitfunc == 'gauss':
            # NOTE: errors are not yet reported for gaussian components
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            c = first
            for i in range(0, len(pars), 3):
                out += ' %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c, pars[i], ounit, pars[i+1], aunit, pars[i+2], aunit)
                if area:
                    out += ' area = %3.3f %s %s\n' % (area[i], ounit, aunit)
                c += 1
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        fixed = self.fitter.getfixedparameters()
        if rcParams['verbose']:
            print(self._format_pars(pars, fixed, None, None))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        ch2 = self.fitter.getchi2()
        if rcParams['verbose']:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fits have been committed (subtracted)
        from the fitted row.
        """
        if not self.fitted:
            return self._error("Not yet fitted.", RuntimeError)
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._error("Not a scantable", TypeError)
        scan = self.data.copy()
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        print_log()
        return scan

    def plot(self, residual=False, components=None, plotparms=False, filename=None):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
            filename:    file to save the plot to when not using the GUI
                         plotter
        """
        if not self.fitted:
            return
        if not self._p or self._p.is_dead:
            if rcParams['plotter.gui']:
                from asap.asaplotgui import asaplotgui as asaplot
            else:
                from asap.asaplot import asaplot
            self._p = asaplot()
        self._p.hold()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        ylab = 'Ordinate'
        # without a scantable fall back to the mask the fit was done with
        m = self.mask
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
            ylab = self.data._get_ordinate_label()

        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
        self._p.palette(0, colours)
        # NOTE(review): ma.MA.MaskedArray usage was marked untested upstream
        from matplotlib.numerix import ma, logical_not, array
        # the data mask flags valid points; a masked array masks the others
        if m is None:
            invalid = None
        else:
            invalid = logical_not(array(m, copy=0))

        def masked(values):
            # wrap a vector so invalid channels are not drawn
            return ma.MA.MaskedArray(values, mask=invalid, copy=0)

        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, masked(self.y))
        if residual:
            self._p.palette(1)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, masked(self.get_residual()))
        self._p.palette(2)
        if plotparms:
            self._p.text(0.15, 0.15, str(self.get_parameters()['formatted']), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int): cs = [components]
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c]+str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, masked(self.fitter.evaluate(c)))
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, masked(self.fitter.getfit()))
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, masked(self.fitter.getfit()))
        xlim = [min(self.x), max(self.x)]
        self._p.axes.set_xlim(xlim)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()
        if not rcParams['plotter.gui']:
            self._p.save(filename)
        print_log()

    def auto_fit(self, insitu=None, plot=False):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols.
        Parameters:
            insitu:  modify the fitter's scantable in place rather than a
                     copy (default rcParams['insitu'])
            plot:    plot each fit and ask whether to accept it
                     (default False)

        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            return self._error("Data is not a scantable", TypeError)
        if insitu is None: insitu = rcParams['insitu']
        if not insitu:
            scan = self.data.copy()
        else:
            scan = self.data
        from asap import asaplog
        asaplog.push("Fitting:")
        # fit() only uses self.x/self.y while self.data is None, so detach
        # the scantable for the loop and always restore it afterwards
        origdata = self.data
        self.data = None
        try:
            for r in range(scan.nrow()):
                out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),scan.getbeam(r),scan.getif(r),scan.getpol(r), scan.getcycle(r))
                asaplog.push(out, False)
                self.x = scan._getabcissa(r)
                self.y = scan._getspectrum(r)
                self.fit(r)
                # called for its verbose printout of the parameters
                self.get_parameters()
                if plot:
                    self.plot(residual=True)
                    answer = raw_input("Accept fit ([y]/n): ")
                    if answer.upper() == 'N':
                        continue
                scan._setspectrum(self.fitter.getresidual(), r)
        finally:
            self.data = origdata
        if plot:
            self._p.unmap()
            self._p = None
        print_log()
        return scan
[794]584
Note: See TracBrowser for help on using the repository browser.