source: trunk/python/asaplotbase.py @ 1100

Last change on this file since 1100 was 1100, checked in by mar637, 18 years ago

removed hardcoding of legend font size

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.8 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15from matplotlib.ticker import ScalarFormatter
16from matplotlib.ticker import NullLocator
# Warn about old matplotlib releases.  Compare the numeric (major, minor)
# pair so that e.g. 1.0 is not mistaken for an old 0.x release, and so that
# pre-release tags such as "0.87rc1" do not raise ValueError at import time.
# Version strings that do not start with "<major>.<minor>" are not checked.
_mplversion = match(r'(\d+)\.(\d+)', matplotlib.__version__)
if _mplversion is not None and \
       (int(_mplversion.group(1)), int(_mplversion.group(2))) < (0, 87):
    print("Warning: matplotlib version < 0.87. This might cause errors. Please upgrade.")
del _mplversion
19
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter variant that leaves the first tick (pos == 0) unlabelled.

    Every other tick, including calls made without a position (pos None), is
    formatted exactly as ScalarFormatter would format it.  asaplotbase uses
    it for the last-row panels of a ganged layout.
    """

    def __call__(self, x, pos=None):
        if pos != 0:
            return ScalarFormatter.__call__(self, x, pos)
        # Blank label for the first tick position.
        return ''
26
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    Manages a matplotlib Figure divided into one or more panels (subplots),
    the lines plotted in each panel, colour/line-style cycling, legends and
    layout-dependent font sizes.  This class never creates a canvas
    (self.canvas stays None) and never assigns self.events or self.figmgr,
    which register() and region() use; presumably a subclass supplies them.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        rows, cols  initial panel layout, passed to set_panels().
        title       figure title, see set_title().
        size        figure size (width, height) passed to matplotlib Figure.
        buffering   if true, show() does nothing until release() is called.
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        # Panel geometry read by show(); set_panels() overwrites it.  Setting
        # it here keeps show() working when rows < 1 and no layout exists yet.
        self.rows = 0
        self.cols = 0

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence, overridable with the asap rc parameter
        # 'plotter.colours' (a whitespace-separated string).
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named line styles as dash sequences for Line2D.set_dashes().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        # Unknown names fall back to a solid line, as palette() does.  The
        # entries go to set_dashes(), so they must be dash sequences (the
        # string '-' previously used here is not one).
        solid = self.lsalias['line']
        self.linestyles = [self.lsalias.get(ls, solid) for ls in styles]

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def clear(self):
        """
        Delete all lines from the current panel and clear its axes.

        Line numbering will restart from 1, and the colour and line-style
        cycles restart from the first entry.
        """
        if self.lines:
            # A single delete() call means a single show() instead of one
            # per line.
            self.delete(list(range(len(self.lines))))
        self.axes.clear()
        self.color = 0
        self.linestyle = 0
        # Empty the list in place: self.lines is the same object as the
        # current subplot's 'lines' entry (see subplot()), which show() reads
        # to build legends.  Rebinding it would detach the two.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Set the colour and line-style cycles used by subsequent plot() calls.

        color       0-relative index of the next colour to use; ignored if
                    it is outside the (possibly new) colour map.
        colormap    optional new colour map: a list/tuple of matplotlib colour
                    specifications or a whitespace-separated string of them.
                    Empty values leave the colour map unchanged.
        linestyle   0-relative index of the next line style to use; ignored
                    if it is outside the (possibly new) line-style list.
        linestyles  optional new line styles: a list/tuple or whitespace-
                    separated string of names from line, dashed, dotted,
                    dashdot, dashdotdot, dashdashdot.  Unknown names map to
                    'line'.  Empty values leave the styles unchanged.

        plot() only advances the line-style cycle when the colour map has a
        single entry.
        """
        if colormap:
            if isinstance(colormap, str):
                colormap = colormap.split()
            if isinstance(colormap, (list, tuple)) and len(colormap) > 0:
                self.colormap = list(colormap)
        if 0 <= color < len(self.colormap):
            self.color = color
        if linestyles:
            styles = None
            if isinstance(linestyles, str):
                styles = linestyles.split()
            elif isinstance(linestyles, (list, tuple)):
                styles = linestyles
            if styles:
                solid = self.lsalias['line']
                self.linestyles = [self.lsalias.get(ls, solid) for ls in styles]
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number(s), default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers is a single line number or an iterable of them.  Numbers that
        are out of range or refer to already-deleted lines are ignored.
        show() is called afterwards.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]
        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                    # Hiding the artist also hides any markers, which a
                    # 'None' line style alone leaves visible.
                    line.set_visible(False)
                # Keep the slot so that later line numbers do not shift.
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes (see set_line()).
        """
        return self.attributes

    def hist(self, x=None, y=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        x, y, fmt and add are as in plot(); if only one of x and y is given
        it is used as the y values and x defaults to 0, 1, 2, ...  y is
        normally a masked array, whose mask is carried over to the plotted
        outline; a plain sequence is treated as unmasked.  Nothing is plotted
        if x and y are empty or differ in length.
        """
        from matplotlib.numerix.ma import MaskedArray
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return

        if hasattr(y, 'raw_data'):
            ydat = y.raw_data()
            ymsk = y.raw_mask()
        else:
            ydat = y
            ymsk = None
        # An unmasked array may report its mask as None or as a scalar
        # "no mask" value rather than a per-element sequence.
        if ymsk is None or not hasattr(ymsk, '__len__'):
            ymsk = [False] * len(ydat)

        x2, y2, m2 = self._histogram_steps(x, ydat, ymsk)
        self.plot(x2, MaskedArray(y2, mask=m2, copy=0), fmt, add)

    @staticmethod
    def _histogram_steps(x, ydat, ymsk):
        """
        Return (x2, y2, m2), the vertices and mask of a stepped outline.

        Each x[k] appears twice in x2 and each mask value twice in m2; y2
        starts at 0.0, so the outline runs (x0,0) (x0,y0) (x1,y0) (x1,y1)
        ... (x[n-1],y[n-1]).  Empty input gives three empty lists.
        """
        n = len(x)
        if n == 0:
            return [], [], []
        x2 = []
        m2 = []
        for k in range(n):
            x2.extend([x[k], x[k]])
            m2.extend([ymsk[k], ymsk[k]])
        y2 = [0.0]
        for i in range(1, 2 * n):
            y2.append(ydat[(i - 1) // 2])
        return x2, y2, m2

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Set the legend location used by show() for every panel.

        loc is an integer location code, passed to matplotlib's legend():
             0: best
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        Any other value (including the default None) disables the legend.
        The setting is applied at the next show().
        """
        if isinstance(loc, int) and 0 <= loc <= 10:
            self.loc = loc
        else:
            self.loc = None

    def plot(self, x=None, y=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes.  The ASAPlot graphics window will be mapped and raised.

        The argument list works a bit like the matlab plot() function: if
        only one of x and y is given it is used as y and x defaults to
        0, 1, 2, ...  fmt is an optional matplotlib format string.

        add, if given, is the 1-relative number of an existing line to which
        the new segment(s) are appended (0 means the most recent line).
        Out-of-range values and deleted lines start a new line instead.

        Attributes set with set_line() are applied to the line.  If they do
        not include a colour, the next colour in the colour map is used; with
        a single-entry colour map the line-style cycle is advanced instead.
        """
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if fmt is None:
            line = self.axes.plot(x, y)
        else:
            line = self.axes.plot(x, y, fmt)

        i = self._add_target(add)
        if i is None:
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            self.lines[i].extend(line)
        segments = self.lines[i]

        # Set/reset attributes for the line.
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color': gotcolour = True
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and len(self.colormap):
            single = len(self.colormap) == 1
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                if single and self.linestyles:
                    segment.set_dashes(self.linestyles[self.linestyle])

            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def _add_target(self, add):
        """
        Return the index into self.lines that plot(add=...) should extend,
        or None when a new line should be started.
        """
        if add is None:
            return None
        nlines = len(self.lines)
        if add == 0:
            add = nlines
        if not 1 <= add <= nlines:
            return None
        i = add - 1
        if self.lines[i] is None:
            # Deleted line: there is nothing to extend.
            return None
        return i

    def position(self):
        """
        Use the mouse to get a position from a graph.

        Registers a one-shot 'button_press' handler that prints the world
        coordinates of the next click and then deregisters itself.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The return value is [x0, y0, x1, y1] in world coordinates.  At present
        the region is only printed when the mouse button is released, and
        this method returns [0.0, 0.0, 0.0, 0.0] immediately.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print(event.xdata, event.ydata)

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.  Types not present in
        self.events are ignored.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events: return

        if func is None:
            if self.events[type] is not None:
                # It's not clear that this does anything.
                self.canvas.mpl_disconnect(self.events[type])
                self.events[type] = None

                # It seems to be necessary to return events to the toolbar.
                if type == 'motion_notify':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.mouse_move)
                elif type == 'button_press':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.press)
                elif type == 'button_release':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.release)

        else:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None, papertype=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  The image format is determined
        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
        file name is specified 'asapyyyymmdd_hhmmss.png' is created in the
        current directory.  Environment variables in fname are expanded.

        orientation  'portrait' or 'landscape' for PostScript output; by
                     default landscape is used when the figure is wider than
                     it is tall.
        dpi          resolution for non-PostScript output (default 150).
        papertype    PostScript paper type; defaults to the asap rc
                     parameter 'plotter.papertype'.

        IOErrors while writing are reported, not raised.
        """
        if papertype is None:
            papertype = asaprcParams['plotter.papertype']
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        from os.path import expandvars
        fname = expandvars(fname)

        # The last three characters select the format.  '.ps' keeps its dot
        # so that 'foo.eps' is not mistaken for PostScript.
        suffix = fname[-3:].lower()
        if suffix not in ('png', '.ps', 'eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if suffix == '.ps':
                w = self.figure.figwidth.get()
                h = self.figure.figheight.get()

                if orientation is None:
                    # auto oriented
                    if w > h:
                        orientation = 'landscape'
                    else:
                        orientation = 'portrait'
                from matplotlib.backends.backend_ps import papersize
                pw, ph = papersize[papertype.lower()]
                # Scale the figure to fill the page for the chosen
                # orientation while keeping its aspect ratio.
                if orientation == 'landscape':
                    ds = min(ph/w, pw/h)
                else:
                    ds = min(pw/w, ph/h)
                self.figure.set_figsize_inches((ds * w, ds * h))
                try:
                    self.figure.savefig(fname, orientation=orientation,
                                        papertype=papertype.lower())
                finally:
                    # Restore the on-screen size even if saving failed.
                    self.figure.set_figsize_inches((w, h))
            else:
                if dpi is None:
                    dpi = 150
                self.figure.savefig(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            msg = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, msg))
            return

    @staticmethod
    def _translate_name(what):
        """Map a trailing 'colour' in a setter name to 'color'."""
        if what[-6:] == 'colour':
            what = what[:-6] + 'color'
        return what

    @staticmethod
    def _translate_kwargs(kwargs):
        """Return kwargs with lower-cased keys and 'colour' renamed 'color'."""
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v
        return newargs

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set an attribute of the current axes by calling Axes.set_<what>().

        A trailing 'colour' in what is translated to 'color'.  args and
        kwargs are passed on to the setter; keyword names are lower-cased
        and 'colour' is accepted as a synonym for 'color'.  Does nothing if
        what is None; otherwise show() is called afterwards.
        """
        if what is None: return
        what = self._translate_name(what)
        getattr(self.axes, "set_%s" % what)(*args,
                                            **self._translate_kwargs(kwargs))
        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set an attribute of the figure by calling Figure.set_<what>().

        A trailing 'colour' in what is translated to 'color'.  args and
        kwargs are passed on to the setter; keyword names are lower-cased
        and 'colour' is accepted as a synonym for 'color'.  Does nothing if
        what is None; otherwise show() is called afterwards.
        """
        if what is None: return
        what = self._translate_name(what)
        getattr(self.figure, "set_%s" % what)(*args,
                                              **self._translate_kwargs(kwargs))
        self.show()

    @staticmethod
    def _merge_limits(old, new):
        """Return list(old) with every non-None entry of new substituted."""
        lim = list(old)
        if new is not None:
            for k, v in enumerate(new):
                if v is not None:
                    lim[k] = v
        return lim

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        None (for the whole argument or a single entry) keeps the current
        limit.  The current subplot selection is left unchanged.
        """
        for sp in self.subplots:
            ax = sp['axes']
            ax.set_xlim(self._merge_limits(ax.get_xlim(), xlim))
            ax.set_ylim(self._merge_limits(ax.get_ylim(), ylim))
        return

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never recorded is harmless).
        """
        redraw = False
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'

            if number is not None and 0 <= number < len(self.lines):
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw: self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        If ganged is true (regular layouts only) the spacing between panels
        is removed and tick labels and axis labels of interior panels are
        suppressed.
        """
        if rows < 1: rows = 1

        if cols <= 0:
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({})

            self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                            cols, n+1)
            self.subplots[i]['lines'] = []

            if i == 0: self.subplot(0)

            # Irregular layout: no single row/column count applies.
            self.rows = 0
            self.cols = 0

        else:
            # Any n outside the layout starts afresh, as documented.
            if len(self.subplots):
                self.figure.clear()
                self.set_title()
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols
            if ganged:
                hsp, wsp = None, None
                if rows > 1: hsp = 0.0001
                if cols > 1: wsp = 0.0001
                self.figure.subplots_adjust(wspace=wsp, hspace=hsp)
            for i in range(nplots):
                self.subplots.append({})
                self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                                cols, i+1)
                self.subplots[i]['lines'] = []

                if ganged:
                    # Suppress tick labelling for interior subplots.
                    if i <= (rows-1)*cols - 1:
                        if i+cols < nplots:
                            # Suppress x-labels for frames with
                            # adjacent frames below them.
                            self.subplots[i]['axes'].xaxis.set_major_locator(NullLocator())
                            self.subplots[i]['axes'].xaxis.label.set_visible(False)
                    if i%cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in self.subplots[i]['axes'].yaxis.majorTicks:
                            tick.label1On = False
                        self.subplots[i]['axes'].yaxis.label.set_visible(False)
                    # disable the first tick of [1:ncol-1] of the last row
                    if (nplots-cols) < i <= nplots-1:
                        self.subplots[i]['axes'].xaxis.set_major_formatter(MyFormatter())
            self.rows = rows
            self.cols = cols
            self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The title text drawn by an earlier call is updated in place while it
        is still part of the figure, rather than drawing a second text on
        top of it.
        """
        if title is not None:
            self.title = title

        titletext = getattr(self, '_titletext', None)
        if titletext is not None and titletext in self.figure.texts:
            titletext.set_text(self.title)
        else:
            self._titletext = self.figure.text(0.5, 0.95, self.title,
                                               horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering (see hold()), rebuild each panel's legend when a
        legend location is set (see legend()) and scale the title, tick-label
        and axis-label font sizes down according to the panel layout.
        """
        if self.buffering:
            return

        if self.loc is not None:
            for sp in self.subplots:
                lines = []
                labels = []
                for k, line in enumerate(sp['lines']):
                    if line is None:
                        continue
                    lines.append(line[0])
                    lbl = line[0].get_label()
                    if lbl == '':
                        # Unlabelled lines are named by 1-relative number.
                        lbl = str(k + 1)
                    labels.append(lbl)

                if len(lines):
                    sp['axes'].legend(tuple(lines), tuple(labels), self.loc)
                else:
                    # No live lines: single blank label (original behaviour;
                    # NOTE(review): presumably meant to blank the legend).
                    sp['axes'].legend(' ')

        from matplotlib.artist import setp
        xts = rcParams['xtick.labelsize'] - self.cols // 2
        yts = rcParams['ytick.labelsize'] - self.rows // 2
        for sp in self.subplots:
            ax = sp['axes']

            # Shrink the title relative to a remembered base size.  Using the
            # current size would shrink it again on every call.  The base is
            # re-read whenever the size differs from the one last set here
            # (i.e. something else changed it).
            base, applied = sp.get('_titlesize', (None, None))
            current = ax.title.get_size()
            if applied is None or current != applied:
                base = current
            applied = base - (self.cols + self.rows)
            ax.title.set_size(applied)
            sp['_titlesize'] = (base, applied)

            setp(ax.get_xticklabels(), fontsize=xts)
            setp(ax.get_yticklabels(), fontsize=yts)

            # Axis labels are scaled from the rc default, not their current
            # size, so repeated calls do not compound.
            labelsize = rcParams['axes.labelsize']
            off = 0
            if self.cols > 1: off = self.cols
            ax.xaxis.label.set_size(labelsize - off)
            off = 0
            if self.rows > 1: off = self.rows
            ax.yaxis.label.set_size(labelsize - off)

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        i selects a panel directly; inc moves relative to the current one.
        The result wraps modulo the number of panels.  Does nothing if no
        panels exist.
        """
        l = len(self.subplots)
        if l:
            if i is not None:
                self.i = i

            if inc is not None:
                self.i += inc

            self.i %= l
            self.axes  = self.subplots[self.i]['axes']
            self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure; arguments are as for matplotlib Figure.text().
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.