source: branches/Release-2-fixes/python/asapfitter.py@ 3037

Last change on this file since 3037 was 667, checked in by mar637, 19 years ago

Fixed colour bug: matplotlib (mpl) does not understand "grey80"-style colour names.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.0 KB
RevLine 
[113]1import _asap
[259]2from asap import rcParams
[113]3
class fitter:
    """
    The fitting class for ASAP.

    Fits either a polynomial or a sum of Gaussians (through the
    _asap.fitter engine) to plain x/y vectors (see set_data) or to rows
    of a scantable (see set_scan).
    """
    def _verbose(self, *args):
        """
        Get or set the stdout verbosity flag.
        Called without arguments, return the current flag.
        Called with a bool, store it as the new flag (returns None).
        Any other argument is ignored.
        """
        # Check for "no arguments" first; indexing args[0] on an empty
        # tuple would raise IndexError.
        if len(args) == 0:
            return self._vb
        if type(args[0]) is bool:
            self._vb = args[0]
        return

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None       # 'poly' or 'gauss', see set_function
        self.fitfuncs = None      # one function name per component
        self.fitted = False
        self.data = None          # scantable given to set_scan, if any
        self.components = 0       # set_function replaces this with a list
        self._fittedrow = 0
        self._p = None            # ASAPlot, created on first plot()
        self._vb = True
        self._selection = None

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abscissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abscissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points valid)

        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None' rather than '== None': mask may be an array, for
        # which '==' compares element-wise.
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels valid)
        """
        if thescan is None:
            print("Please give a correct scan")
            # Without a scan there is nothing to take nchan() from.
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        if 'poly' in kwargs:
            n = kwargs.get('poly')
            self.fitfunc = 'poly'
            # Reset fitfuncs so no stale gaussian names from an earlier
            # call reach store_fit or plot.
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            # each gaussian has three parameters: peak, centre, FWHM
            self.components = [3] * n
        else:
            print("Invalid function type.")
            return
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        If the fit raises a RuntimeError the message is printed and the
        fitter is left in the "not fitted" state.
        Parameters:
            row:    specify the row in the scantable
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            print("Fitter not yet initialised. Please set data & fit function")
            return
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            print("Fitting:")
            # Temporarily force the scantable verbose while reading the
            # cursor, and always restore the caller's setting.
            vb = self.data._vb
            self.data._vb = True
            try:
                self._selection = self.data.get_cursor()
            finally:
                self.data._vb = vb
            # public alias kept for code that reads 'fitter.selection'
            self.selection = self._selection
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            ps = self.fitter.getparameters()
            # only estimate starting values if none were set by the user
            if len(ps) == 0:
                self.fitter.estimate()
        try:
            self.fitter.fit()
        except RuntimeError:
            import sys
            print(sys.exc_info()[1])
            # A failed fit must not be reported as a valid one.
            self.fitted = False
            return
        self._fittedrow = row
        self.fitted = True
        return

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Does nothing unless a fit has been made on a scantable.
        """
        if self.fitted and self.data is not None:
            pars = list(self.fitter.getparameters())
            fixed = list(self.fitter.getfixedparameters())
            self.data._addfit(self._fittedrow, pars, fixed,
                              self.fitfuncs, self.components)

    def set_parameters(self, params, fixed=None, component=None):
        """
        Set the parameters to be fitted.
        Parameters:
            params:    a vector of parameters
            fixed:     a vector of which parameters are to be held fixed
                       (default is none)
            component: in case of multiple gaussians, the index of the
                       component
        """
        if self.fitfunc is None:
            print("Please specify a fitting function first.")
            return
        if self.fitfunc == "gauss" and component is not None:
            npars = 3 * len(self.components)
            pars = list(self.fitter.getparameters())
            fxd = list(self.fitter.getfixedparameters())
            # Keep values already held by the fitter (earlier calls or a
            # previous fit). Start from zeros / all-free only when the
            # fitter has no full vector for the current components.
            if len(pars) != npars:
                pars = [0.0] * npars
            if len(fxd) != npars:
                fxd = [False] * npars
            i = 3 * component
            pars[i:i+3] = params
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=False, centerfixed=False,
                             fhwmfixed=False,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            print("Function only operates on Gaussian components.")
            return
        if 0 <= component < len(self.components):
            self.set_parameters([peak, centre, fhwm],
                                [peakfixed, centerfixed, fhwmfixed],
                                component)
        else:
            print("Please select a valid component.")
            return

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
             component:    get the parameters for the specified component
                           only, default is all components
        Returns:
             (parameters, fixed flags, formatted string); the string is
             also printed when verbose.
        """
        if not self.fitted:
            print("Not yet fitted.")
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        if component is not None and self.fitfunc == "gauss":
            i = 3 * component
            cpars = pars[i:i+3]
            cfixed = fixed[i:i+3]
        else:
            cpars = pars
            cfixed = fixed
        fpars = self._format_pars(cpars, cfixed)
        if self._vb:
            print(fpars)
        return cpars, cfixed, fpars

    def _format_pars(self, pars, fixed=None):
        """
        Return a human-readable string for a parameter vector.
        'fixed' is optional (all parameters are treated as free when
        it is None); it is only used for polynomial output.
        """
        if fixed is None:
            fixed = [False] * len(pars)
        out = ''
        if self.fitfunc == 'poly':
            for c in range(len(pars)):
                fix = ""
                if fixed[c]:
                    fix = "(fixed)"
                out += ' p%d%s= %3.3f,' % (c, fix, pars[c])
            out = out[:-1]  # remove trailing ','
        elif self.fitfunc == 'gauss':
            i = 0
            c = 0
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            # parameters come in (peak, centre, FWHM) triples
            while i < len(pars):
                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
                c += 1
                i += 3
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        if self._vb:
            # estimates carry no fixed flags
            print(self._format_pars(pars))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            print("Not yet fitted.")
        ch2 = self.fitter.getchi2()
        if self._vb:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fit has been committed (subtracted):
        a copy of the scantable whose fitted row holds the fit residual.
        Returns None (after printing a message) if there is no fit or
        the data is not a scantable.
        """
        if not self.fitted:
            print("Not yet fitted.")
            return
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        scan = self.data.copy()
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        return scan

    def plot(self, residual=False, components=None, plotparms=False):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
        """
        if not self.fitted:
            return
        if not self._p or self._p.is_dead:
            from asap.asaplot import ASAPlot
            self._p = ASAPlot()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        # default so the label is defined when there is no scantable
        ylab = 'Ordinate'
        m = ()
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
            ylab = self.data._get_ordinate_label()

        colours = ["#777777","#bbbbbb","red","orange","purple","green","magenta", "cyan"]
        self._p.palette(0, colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(1)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(2)
        if plotparms:
            self._p.text(0.15, 0.15, str(self.get_parameters()[2]), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    if self.fitfuncs:
                        lab = self.fitfuncs[c] + str(c)
                    else:
                        lab = self.fitfunc + str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols. Each fitted spectrum is replaced by its
        residual; rows whose fit fails are left unchanged.
        Parameters:
            insitu:    if True modify the fitter's scantable in place,
                       otherwise work on a copy and return it
                       (default: rcParams['insitu'])
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        if insitu is None:
            insitu = rcParams['insitu']
        if not insitu:
            scan = self.data.copy()
        else:
            scan = self.data
        origdata = self.data
        vb = scan._vb
        scan._vb = False
        sel = scan.get_cursor()
        rows = range(scan.nrow())
        # fit() re-reads x/y from self.data when it is set; detach it so
        # the per-row vectors assigned below are fitted instead.
        self.data = None
        try:
            for i in range(scan.nbeam()):
                scan.setbeam(i)
                for j in range(scan.nif()):
                    scan.setif(j)
                    for k in range(scan.npol()):
                        scan.setpol(k)
                        if self._vb:
                            print("Fitting:")
                            print('Beam[%d], IF[%d], Pol[%d]' % (i,j,k))
                        for iRow in rows:
                            self.x = scan._getabcissa(iRow)
                            self.y = scan._getspectrum(iRow)
                            self.fit()
                            if not self.fitted:
                                # fit() already printed the error; keep
                                # this row's spectrum unchanged
                                continue
                            self.get_parameters()  # prints when verbose
                            scan._setspectrum(self.fitter.getresidual(),
                                              iRow)
        finally:
            scan.set_cursor(sel[0], sel[1], sel[2])
            scan._vb = vb
            # Reattach the scantable so later calls still have it.
            self.data = origdata
            # The fitter state now describes a single row of 'scan', not
            # self._fittedrow of self.data; drop it so store_fit/plot do
            # not pair it with the wrong row.
            self.fitted = False
        if not insitu:
            return scan
Note: See TracBrowser for help on using the repository browser.