source: trunk/python/asaplotbase.py @ 1086

Last change on this file since 1086 was 1086, checked in by mar637, 18 years ago

use MA instead of homemade masking; some work on auto-scaling label size. This probably breaks old (woody) versions of mpl (matplotlib)

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.8 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15from matplotlib.ticker import ScalarFormatter
16from matplotlib.ticker import NullLocator
17
class MyFormatter(ScalarFormatter):
    """
    Scalar tick formatter that leaves the tick at position 0 unlabelled and
    otherwise defers to ScalarFormatter.

    set_panels() installs it on the x axis of bottom-row panels in a ganged
    layout.
    """

    def __call__(self, x, pos=None):
        """Return '' for pos == 0, otherwise ScalarFormatter's label for x."""
        is_first_tick = (pos == 0)
        if is_first_tick:
            return ''
        return ScalarFormatter.__call__(self, x, pos)
24
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    Holds a matplotlib Figure divided into one or more panels (see
    set_panels()).  The current panel is chosen with subplot(); plot() and
    hist() draw into it.  Methods that change the plot call show(), which
    does nothing while graphics are buffered (see hold()/release()).
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        The default colour and line-style cycles are overridden by asap's
        'plotter.colours' and 'plotter.linestyles' rc parameters when those
        are non-empty strings.
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        # Panel state; set_panels()/subplot() replace these.  Initialised
        # here so the object is usable even when rows < 1.
        self.rows = 0
        self.cols = 0
        self.i = 0
        self.lines = []

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Set matplotlib default colour sequence.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()

        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named line styles and their dash sequences (for Line2D.set_dashes).
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        # Unknown names fall back to a solid dash sequence (as in palette());
        # a plain '-' here would later be handed to set_dashes().
        self.linestyles = self._dashes_for(styles)

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def _dashes_for(self, styles):
        """Map line-style names to dash sequences; unknown names map to solid."""
        solid = self.lsalias['line']
        return [self.lsalias.get(ls, solid) for ls in styles]

    @staticmethod
    def _colour_kwargs(kwargs):
        """Return kwargs with lower-cased keys and 'colour' renamed to 'color'."""
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour':
                k = 'color'
            newargs[k] = v
        return newargs

    def clear(self):
        """
        Delete all lines from the current panel and clear its axes.  Line
        numbering restarts and the colour cycle is reset.
        """
        # A single delete() call hides every line and redraws only once.
        self.delete(list(range(len(self.lines))))
        self.axes.clear()
        self.color = 0
        # Empty the list in place: self.lines is the same object as
        # self.subplots[self.i]['lines'], which show() reads for legends.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Set the colour and line-style cycles used for subsequent lines.

        color       index into the colour map of the next colour to use;
                    ignored if out of range.
        colormap    new colour map: a list/tuple of colours or a
                    whitespace-separated string of colour names.
        linestyle   index into the line-style list of the next style to use;
                    ignored if out of range.
        linestyles  new line styles: a list/tuple or whitespace-separated
                    string of names from self.lsalias ('line', 'dashed',
                    'dotted', 'dashdot', 'dashdotdot', 'dashdashdot').
                    Unknown names map to a solid line.

        plot() only cycles line styles when the colour map holds one colour.
        """
        if colormap:
            if isinstance(colormap, list):
                self.colormap = colormap
            elif isinstance(colormap, tuple):
                self.colormap = list(colormap)
            elif isinstance(colormap, str):
                self.colormap = colormap.split()
        if 0 <= color < len(self.colormap):
            self.color = color
        if linestyles:
            if isinstance(linestyles, str):
                styles = linestyles.split()
            else:
                styles = list(linestyles)
            self.linestyles = self._dashes_for(styles)
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.
        numbers may be a single number or an iterable of numbers.
        The remaining lines are NOT renumbered.

        Deleted lines are hidden (line style 'None') and their slot in
        self.lines is set to None; out-of-range or already deleted numbers
        are ignored.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes (as recorded by set_line()).
        """
        return self.attributes

    def hist(self, x=None, y=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        x    bin start positions; defaults to 0, 1, ..., len(y)-1.
        y    bin values: a masked array (its mask is carried over) or a
             plain sequence (nothing masked).
        fmt  is the line style as in plot().
        add  as in plot().

        The histogram is drawn as a step outline via plot().  Nothing is drawn
        if y is missing or empty, or if x and y differ in length.
        """
        if y is None:
            return
        if x is None:
            x = range(len(y))
        npts = len(x)
        if npts == 0 or npts != len(y):
            return

        from matplotlib.numerix.ma import MaskedArray
        try:
            ymsk = y.raw_mask()
            ydat = y.raw_data()
        except AttributeError:
            # Plain sequence: there is no mask to carry over.
            ymsk, ydat = None, y
        if ymsk is None:
            # NOTE(review): assumes raw_mask() yields None when nothing is
            # masked — confirm for the MA implementation in use.
            ymsk = [0] * npts

        # Each bin contributes two outline points at its start x: the
        # previous height (0.0 before the first bin) and its own height.
        x2 = [x[j // 2] for j in range(2 * npts)]
        m2 = [ymsk[j // 2] for j in range(2 * npts)]
        y2 = [0.0] + [ydat[(j - 1) // 2] for j in range(1, 2 * npts)]

        self.plot(x2, MaskedArray(y2, mask=m2, copy=0), fmt, add)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.

        While buffering, show() does nothing; release() ends buffering and
        calls show().
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Set the legend location and redraw.

        loc is a matplotlib legend location code:
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        Integer values outside 0-10 are replaced by 0.  A non-integer loc
        (including the default None) keeps the current location and only
        redraws.  show() draws legends only while self.loc is not None.
        """
        if isinstance(loc, int):
            if not 0 <= loc <= 10:
                loc = 0
            self.loc = loc
        self.show()

    def plot(self, x=None, y=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes, then call show().

        The argument list works a bit like the matlab plot() function:
        plot(y), plot(x, y) or plot(x, y, fmt).

        add is a 1-relative line number whose segments the new plot is
        appended to (0 means the last line).  If add is None, out of range,
        or names a deleted line, a new line is started instead.

        Unless a colour was set through set_line(), the line takes the next
        colour of the colour map; with a single-colour map the line style
        (dash sequence) is cycled instead.
        """
        if x is None:
            if y is None:
                return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if fmt is None:
            line = self.axes.plot(x, y)
        else:
            line = self.axes.plot(x, y, fmt)

        # Add to an existing line?
        i = None
        if add is not None:
            target = add
            if target == 0:
                target = len(self.lines)
            if 1 <= target <= len(self.lines) and self.lines[target - 1] is not None:
                i = target - 1
        if i is None:
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            self.lines[i].extend(line)

        # Set/reset attributes for the line.
        gotcolour = 'color' in self.attributes
        for k, v in self.attributes.items():
            for segment in self.lines[i]:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and len(self.colormap):
            ncolours = len(self.colormap)
            single = (ncolours == 1)
            colour = self.colormap[self.color % ncolours]
            for segment in self.lines[i]:
                segment.set_color(colour)
                if single and self.linestyles:
                    segment.set_dashes(self.linestyles[self.linestyle])

            self.color = (self.color + 1) % ncolours

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        The world coordinates of the next mouse-button press are printed.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        Press, drag and release a mouse button.  A rectangle is drawn on the
        Tk canvas while dragging, and the corners are printed in world
        coordinates on release.

        The return value is meant to be [x0, y0, x1, y1] in world
        coordinates.  Event handling does not block, so this currently always
        returns [0.0, 0.0, 0.0, 0.0].
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            tkcanvas = self.canvas._tkcanvas
            tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = tkcanvas.create_rectangle(
                self.rect['x'], self.rect['y'],
                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % tuple(self.rect['world']))

        self.register('button_press', region_start)

        # TODO: block and return the result (currently printed by
        # region_disable) when that becomes possible in matplotlib.
        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.  Types not present in
        self.events are ignored (self.events is expected to be set up
        elsewhere, presumably by a subclass).

        Note that in TkAgg keyboard button presses don't generate an event.
        """
        if type not in self.events:
            return

        if func is not None:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)
            return

        if self.events[type] is None:
            return

        # It's not clear that this does anything.
        self.canvas.mpl_disconnect(self.events[type])
        self.events[type] = None

        # It seems to be necessary to return events to the toolbar.
        toolbar = self.figmgr.toolbar
        if type == 'motion_notify':
            self.canvas.mpl_connect(type + '_event', toolbar.mouse_move)
        elif type == 'button_press':
            self.canvas.mpl_connect(type + '_event', toolbar.press)
        elif type == 'button_release':
            self.canvas.mpl_connect(type + '_event', toolbar.release)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    @staticmethod
    def _ps_supports_papertype(version):
        """Return True if matplotlib `version` can be given papertype for PS."""
        # Early matplotlib versions (< 0.87) have a PostScript bug that
        # papertype triggers.  Compare (major, minor), so 1.x counts as new.
        m = match(r'(\d+)\.(\d+)', version)
        if m is None:
            return True
        return (int(m.group(1)), int(m.group(2))) >= (0, 87)

    def _save_ps(self, fname, orientation=None):
        """Write PostScript scaled to fit A4, restoring the figure size after."""
        w = self.figure.figwidth.get()
        h = self.figure.figheight.get()

        if orientation is None:
            # auto oriented
            if w > h:
                orientation = 'landscape'
            else:
                orientation = 'portrait'
        a4w = 8.25
        a4h = 11.25
        if orientation == 'landscape':
            ds = min(a4h / w, a4w / h)
        else:
            ds = min(a4w / w, a4h / h)
        self.figure.set_figsize_inches((ds * w, ds * h))
        try:
            from matplotlib import __version__ as mv
            if self._ps_supports_papertype(mv):
                self.figure.savefig(fname, orientation=orientation,
                                    papertype="a4")
            else:
                self.figure.savefig(fname, orientation=orientation)
        finally:
            # Restore the on-screen size even if writing failed.
            self.figure.set_figsize_inches((w, h))

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file; environment variables in it
        are expanded.  The image format is determined from the file suffix;
        'png', 'ps', and 'eps' are recognized.  If no file name is specified
        'asapyyyymmdd_hhmmss.png' is created in the current directory.

        orientation ('landscape' or 'portrait') is used for 'ps' only; by
        default it follows the figure's aspect ratio.  PostScript output is
        scaled to fit an A4 page.
        dpi is used for 'png' and 'eps' and defaults to 150.

        An IOError while writing is reported by a printed message, not raised.
        """
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap' + dstr + '.png'

        from os.path import expandvars
        fname = expandvars(fname)

        suffix = fname[-3:].lower()
        if suffix not in ('png', '.ps', 'eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if suffix == '.ps':
                self._save_ps(fname, orientation)
            else:
                if dpi is None:
                    dpi = 150
                self.figure.savefig(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            err = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, err))
            return

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the axes by calling the relevant Axes.set_*()
        method, then call show().  'colour' is accepted as a synonym for
        'color' both in what and in keyword names.
        """
        if what is None:
            return
        if what.endswith('colour'):
            what = what[:-6] + 'color'

        getattr(self.axes, "set_%s" % what)(*args, **self._colour_kwargs(kwargs))

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method, then call show().  'colour' is accepted as a synonym for
        'color' both in what and in keyword names.
        """
        if what is None:
            return
        if what.endswith('colour'):
            what = what[:-6] + 'color'

        getattr(self.figure, "set_%s" % what)(*args, **self._colour_kwargs(kwargs))
        self.show()

    @staticmethod
    def _merge_limits(current, requested):
        """Return list(current) with every non-None entry of requested applied."""
        merged = list(current)
        if requested is not None:
            for j, value in enumerate(requested):
                if value is not None:
                    merged[j] = value
        return merged

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        A None xlim/ylim, or a None entry within one, keeps that panel's
        current limit.  The panel selected by subplot() is left unchanged.
        """
        for sp in self.subplots:
            ax = sp['axes']
            newx = self._merge_limits(ax.get_xlim(), xlim)
            newy = self._merge_limits(ax.get_ylim(), ylim)
            ax.set_xlim(newx)
            ax.set_ylim(newy)

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never recorded is harmless).
        """
        redraw = False
        existing = number is not None and 0 <= number < len(self.lines)
        for k, v in self._colour_kwargs(kwargs).items():
            if existing:
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw:
            self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True the panels are packed without gaps and interior
        tick labels are suppressed.
        """
        if rows < 1:
            rows = 1

        if cols <= 0:
            side = int(sqrt(rows))
            if side * side < rows:
                side += 1
            cols = side

            if side * (side - 1) >= rows:
                side -= 1
            rows = side

        add_single = 0 <= n < rows * cols
        if not add_single and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if add_single:
            i = len(self.subplots)
            self.subplots.append({'axes': self.figure.add_subplot(rows, cols, n + 1),
                                  'lines': []})

            if i == 0:
                self.subplot(0)

            self.rows = 0
            self.cols = 0
            return

        self.subplots = []

        if nplots < 1 or rows * cols < nplots:
            nplots = rows * cols
        if ganged:
            hsp, wsp = None, None
            if rows > 1:
                hsp = 0.0001
            if cols > 1:
                wsp = 0.0001
            self.figure.subplots_adjust(wspace=wsp, hspace=hsp)
        for i in range(nplots):
            ax = self.figure.add_subplot(rows, cols, i + 1)
            self.subplots.append({'axes': ax, 'lines': []})
            if ganged:
                self._gang_axes(ax, i, cols, nplots)
        self.rows = rows
        self.cols = cols
        self.subplot(0)

    def _gang_axes(self, ax, i, cols, nplots):
        """Suppress interior tick labelling of panel i in a ganged layout."""
        if i + cols < nplots:
            # Another panel sits directly below: hide x ticks and x label.
            ax.xaxis.set_major_locator(NullLocator())
            ax.xaxis.label.set_visible(False)
        if i % cols:
            # Suppress y-labels for frames not in the left column.
            for tick in ax.yaxis.majorTicks:
                tick.label1On = False
            ax.yaxis.label.set_visible(False)
        if nplots - cols < i <= nplots - 1:
            # Last 'cols' panels except the first of them: blank the first
            # x tick label.
            ax.xaxis.set_major_formatter(MyFormatter())

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.  Each call adds a new text item at the top centre of the
        figure.
        """
        if title is not None:
            self.title = title

        self.figure.text(0.5, 0.95, self.title, horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering, redraw each panel's legend (while self.loc is not
        None) and rescale title, tick-label and axis-label font sizes to the
        panel layout.
        """
        if self.buffering:
            return
        if self.loc is not None:
            for sp in self.subplots:
                self._draw_legend(sp)
        self._scale_fonts()

    def _draw_legend(self, sp):
        """Draw panel sp's legend; unlabelled lines get their 1-relative index."""
        lines = []
        labels = []
        for idx, line in enumerate(sp['lines']):
            if line is None:
                continue
            lines.append(line[0])
            lbl = line[0].get_label()
            if lbl == '':
                lbl = str(idx + 1)
            labels.append(lbl)

        if lines:
            # NOTE: this modifies matplotlib's global rcParams.
            rcParams['legend.fontsize'] = 8
            sp['axes'].legend(tuple(lines), tuple(labels), self.loc)
        else:
            sp['axes'].legend((' '))

    def _scale_title(self, sp, ax):
        """Shrink the title of panel sp by rows+cols points, without compounding."""
        current = ax.title.get_size()
        if sp.get('titlesize_set') != current:
            # First call, or the size was changed elsewhere since the last
            # call: use the current size as the unscaled base.
            sp['titlesize_base'] = current
        size = sp['titlesize_base'] - (self.cols + self.rows)
        ax.title.set_size(size)
        sp['titlesize_set'] = size

    def _scale_fonts(self):
        """Reduce title, tick-label and axis-label sizes as panels multiply."""
        from matplotlib.artist import setp
        xbase = rcParams['xtick.labelsize']
        ybase = rcParams['ytick.labelsize']
        xts = xbase - self.cols // 2
        yts = ybase - self.rows // 2
        for sp in self.subplots:
            ax = sp['axes']
            self._scale_title(sp, ax)
            setp(ax.get_xticklabels(), fontsize=xts)
            setp(ax.get_yticklabels(), fontsize=yts)
            xoff = 0
            if self.cols > 1:
                xoff = self.cols
            ax.xaxis.label.set_size(xbase - xoff)
            yoff = 0
            if self.rows > 1:
                yoff = self.rows
            ax.yaxis.label.set_size(ybase - yoff)

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        i selects a panel directly, inc moves relative to the current one;
        the result wraps around the number of panels.  Does nothing if no
        panels exist.
        """
        count = len(self.subplots)
        if not count:
            return
        if i is not None:
            self.i = i
        if inc is not None:
            self.i += inc
        self.i %= count
        self.axes = self.subplots[self.i]['axes']
        self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure (arguments as for Figure.text()), then call
        show().
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.