source: branches/Release-2-fixes/python/asapfitter.py@ 2301

Last change on this file since 2301 was 667, checked in by mar637, 19 years ago

fixed colour bug. mpl doesn't understand grey80 type colours.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.0 KB
Line 
1import _asap
2from asap import rcParams
3
class fitter:
    """
    The fitting class for ASAP.

    Typical use: give the fitter data ('set_data' for plain vectors or
    'set_scan' for a scantable), choose a model with 'set_function',
    optionally seed parameters with 'set_parameters' /
    'set_gauss_parameters', then call 'fit' and inspect the result with
    the 'get_*' methods or 'plot'.
    """

    def _verbose(self, *args):
        """
        Get or set stdout output.

        Called with a single bool, store it as the verbosity flag and
        return None. Called with no arguments, return the current flag.
        Any other argument is ignored.
        """
        # Test for the no-argument (getter) form first: indexing args[0]
        # on an empty tuple would raise IndexError.
        if len(args) == 0:
            return self._vb
        if isinstance(args[0], bool):
            self._vb = args[0]
        return None

    def __init__(self):
        """
        Create a fitter object. No data or fit function is set yet.
        """
        self.fitter = _asap.fitter()  # underlying fitting engine
        self.x = None           # abscissa values of the data being fitted
        self.y = None           # ordinate values of the data being fitted
        self.mask = None        # channel mask, set by set_data/set_scan
        self.fitfunc = None     # 'poly' or 'gauss', set by set_function
        self.fitfuncs = None    # per-component function names
        self.fitted = False     # True once fit() has been run
        self.data = None        # scantable given to set_scan, if any
        self.components = 0     # set_function: [order] (poly) or [3]*n (gauss)
        self._fittedrow = 0     # row passed to the last fit()
        self._p = None          # ASAPlot instance used by plot()
        self._vb = True         # echo results to stdout
        self._selection = None  # scantable cursor recorded by fit()

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abscissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abscissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points valid)
        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # Identity test on purpose: '== None' against an array mask can be
        # an element-wise comparison rather than a None check.
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels valid)
        """
        if not thescan:
            print("Please give a correct scan")
            # Nothing to fit and nothing to size a default mask from;
            # stop here instead of failing on None.nchan().
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussians specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        if 'poly' in kwargs:
            n = kwargs.get('poly')
            self.fitfunc = 'poly'
            # A polynomial is a single component; keep fitfuncs parallel
            # to components so plot() and store_fit() can index it.
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            self.components = [3] * n
        else:
            print("Invalid function type.")
            return
        # A previous fit no longer matches the new model.
        self.fitted = False
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:         specify the row in the scantable
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            print("Fitter not yet initialised. Please set data & fit function")
            return
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            print("Fitting:")
            # Force the scantable verbose while reading its cursor
            # (presumably so get_cursor() echoes the selection), then
            # restore the caller's setting.
            vb = self.data._vb
            self.data._vb = True
            self._selection = self.data.get_cursor()
            self.data._vb = vb
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            # Gaussian fits need starting values; estimate them unless
            # some were already supplied via set_parameters.
            ps = self.fitter.getparameters()
            if len(ps) == 0:
                self.fitter.estimate()
        try:
            self.fitter.fit()
        except RuntimeError:
            # sys.exc_info() keeps this valid on both old and new Pythons.
            import sys
            print(sys.exc_info()[1])
        self._fittedrow = row
        self.fitted = True
        return

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Does nothing unless a fit has been made and a scantable is set.
        """
        if self.fitted and self.data is not None:
            pars = list(self.fitter.getparameters())
            fixed = list(self.fitter.getfixedparameters())
            self.data._addfit(self._fittedrow, pars, fixed,
                              self.fitfuncs, self.components)

    def set_parameters(self, params, fixed=None, component=None):
        """
        Set the parameters to be fitted.
        Parameters:
            params:      a vector of parameters
            fixed:       a vector of which parameters are to be held fixed
                         (default is none, i.e. leave the fixed flags as
                         they are)
            component:   in case of multiple gaussians, the index of the
                         component
        """
        if self.fitfunc is None:
            print("Please specify a fitting function first.")
            return
        if self.fitfunc == "gauss" and component is not None:
            # Merge this component's three values into the full parameter
            # vector, keeping whatever is already set for the others.
            npar = len(self.components) * 3
            pars = list(self.fitter.getparameters())
            fxd = list(self.fitter.getfixedparameters())
            if len(pars) != npar:
                pars = [0.0] * npar
            if len(fxd) != npar:
                fxd = [False] * npar
            i = 3 * component
            pars[i:i+3] = params
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=False, centerfixed=False,
                             fhwmfixed=False,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            print("Function only operates on Gaussian components.")
            return
        if 0 <= component < len(self.components):
            self.set_parameters([peak, centre, fhwm],
                                [peakfixed, centerfixed, fhwmfixed],
                                component)
        else:
            print("Please select a valid component.")
            return

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
            component:   get the parameters for the specified component
                         only, default is all components
        Returns a tuple (values, fixed flags, formatted string); the
        formatted string is also printed when verbose.
        """
        if not self.fitted:
            print("Not yet fitted.")
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        if component is not None and self.fitfunc == "gauss":
            # Each Gaussian component owns three consecutive parameters.
            i = 3 * component
            pars = pars[i:i+3]
            fixed = fixed[i:i+3]
        fpars = self._format_pars(pars, fixed)
        if self._vb:
            print(fpars)
        return pars, fixed, fpars

    def _format_pars(self, pars, fixed=None):
        """
        Return a human-readable string describing fit parameters.
        Parameters:
            pars:    flat list of parameter values
            fixed:   optional flags parallel to pars marking parameters
                     held fixed (only shown for polynomials); default is
                     none fixed
        For 'poly' the result looks like ' p0= 1.000, p1(fixed)= 2.000';
        for 'gauss' there is one line per (peak, centre, FWHM) triple,
        with units taken from the scantable when one is set. Any other
        function type yields ''.
        """
        if fixed is None:
            fixed = [False] * len(pars)
        if self.fitfunc == 'poly':
            items = []
            for c, (value, isfixed) in enumerate(zip(pars, fixed)):
                fix = ""
                if isfixed:
                    fix = "(fixed)"
                items.append(' p%d%s= %3.3f' % (c, fix, value))
            return ','.join(items)
        if self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            out = ''
            # Only complete (peak, centre, FWHM) triples are formatted.
            for c in range(len(pars) // 3):
                i = 3 * c
                out += ' %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % \
                       (c, pars[i], ounit, pars[i+1], aunit, pars[i+2], aunit)
            return out
        return ''

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        if self._vb:
            print(self._format_pars(pars))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            print("Not yet fitted.")
        ch2 = self.fitter.getchi2()
        if self._vb:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fit has been committed (subtracted)
        from the fitted row. Returns None if nothing has been fitted or the
        data is not a scantable.
        """
        from asap import scantable
        if not self.fitted:
            print("Not yet fitted.")
            return
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        scan = self.data.copy()
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        return scan

    def plot(self, residual=False, components=None, plotparms=False):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
        """
        if not self.fitted:
            return
        # Create the plotter on first use, or again if the old one is dead.
        if not self._p or self._p.is_dead:
            from asap.asaplot import ASAPlot
            self._p = ASAPlot()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(0)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        # Default ordinate label for plain vectors: without it the
        # 'ylabel' call below fails when no scantable is set.
        ylab = 'Ordinate'
        m = ()
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
            ylab = self.data._get_ordinate_label()

        # Greys are given in hex: matplotlib does not understand
        # 'grey80'-style colour names.
        colours = ["#777777", "#bbbbbb", "red", "orange", "purple", "green",
                   "magenta", "cyan"]
        self._p.palette(0, colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(1)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(2)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            if plotparms:
                self._p.text(0.15, 0.15, str(self.get_parameters()[2]), size=8)
            n = len(self.components)
            self._p.palette(3)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c] + str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(2)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(2)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols, replacing each spectrum by its residual.
        Parameters:
            insitu:      modify the fitter's scantable in place instead of
                         a copy (default: rcParams['insitu']); when in
                         place nothing is returned
        The fitter's own scantable reference is left unchanged afterwards.
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        if insitu is None:
            insitu = rcParams['insitu']
        if insitu:
            scan = self.data
        else:
            scan = self.data.copy()
        origdata = self.data
        vb = scan._vb
        scan._vb = False
        sel = scan.get_cursor()
        rows = range(scan.nrow())
        # fit() only uses self.x/self.y while self.data is None; detach the
        # scantable for the loop and restore it (and the scan's cursor and
        # verbosity) afterwards, even if a fit raises.
        self.data = None
        try:
            for i in range(scan.nbeam()):
                scan.setbeam(i)
                for j in range(scan.nif()):
                    scan.setif(j)
                    for k in range(scan.npol()):
                        scan.setpol(k)
                        if self._vb:
                            print("Fitting:")
                            print('Beam[%d], IF[%d], Pol[%d]' % (i, j, k))
                        for irow in rows:
                            self.x = scan._getabcissa(irow)
                            self.y = scan._getspectrum(irow)
                            self.fit(irow)
                            # Called for its verbose echo of the parameters.
                            self.get_parameters()
                            scan._setspectrum(self.fitter.getresidual(), irow)
        finally:
            scan.set_cursor(sel[0], sel[1], sel[2])
            scan._vb = vb
            self.data = origdata
        if not insitu:
            return scan
Note: See TracBrowser for help on using the repository browser.