source: trunk/python/asaplotbase.py @ 1025

Last change on this file since 1025 was 1025, checked in by mar637, 18 years ago

A stab at making axis ticks and labels scale in multi-panel plots. Reverted the version-dependent save addition.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.3 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15from matplotlib.ticker import ScalarFormatter
16from matplotlib.ticker import NullLocator
17
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter variant that leaves the tick at position 0 unlabelled.
    """
    def __call__(self, x, pos=None):
        """
        Return the label for tick value x at tick index pos.

        The tick with pos == 0 gets an empty label; every other tick gets
        the label ScalarFormatter would produce.
        """
        if pos != 0:
            return ScalarFormatter.__call__(self, x, pos)
        return ''
24
25class asaplotbase:
26    """
27    ASAP plotting base class based on matplotlib.
28    """
29
30    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
31        """
32        Create a new instance of the ASAPlot plotting class.
33
34        If rows < 1 then a separate call to set_panels() is required to define
35        the panel layout; refer to the doctext for set_panels().
36        """
37        self.is_dead = False
38        self.figure = Figure(figsize=size, facecolor='#ddddee')
39        self.canvas = None
40
41        self.set_title(title)
42        self.subplots = []
43        if rows > 0:
44            self.set_panels(rows, cols)
45
46        # Set matplotlib default colour sequence.
47        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
48
49        c = asaprcParams['plotter.colours']
50        if isinstance(c,str) and len(c) > 0:
51            self.colormap = c.split()
52
53        self.lsalias = {"line":  [1,0],
54                        "dashdot": [4,2,1,2],
55                        "dashed" : [4,2,4,2],
56                        "dotted" : [1,2],
57                        "dashdotdot": [4,2,1,2,1,2],
58                        "dashdashdot": [4,2,4,2,1,2]
59                        }
60
61        styles = "line dashed dotted dashdot".split()
62        c = asaprcParams['plotter.linestyles']
63        if isinstance(c,str) and len(c) > 0:
64            styles = c.split()
65        s = []
66        for ls in styles:
67            if self.lsalias.has_key(ls):
68                s.append(self.lsalias.get(ls))
69            else:
70                s.append('-')
71        self.linestyles = s
72
73        self.color = 0;
74        self.linestyle = 0;
75        self.attributes = {}
76        self.loc = 0
77
78        self.buffering = buffering
79
80    def clear(self):
81        """
82        Delete all lines from the plot.  Line numbering will restart from 1.
83        """
84
85        for i in range(len(self.lines)):
86           self.delete(i)
87        self.axes.clear()
88        self.color = 0
89        self.lines = []
90
91    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
92        if colormap:
93            if isinstance(colormap,list):
94                self.colormap = colormap
95            elif isinstance(colormap,str):
96                self.colormap = colormap.split()
97        if 0 <= color < len(self.colormap):
98            self.color = color
99        if linestyles:
100            self.linestyles = []
101            if isinstance(linestyles,list):
102                styles = linestyles
103            elif isinstance(linestyles,str):
104                styles = linestyles.split()
105            for ls in styles:
106                if self.lsalias.has_key(ls):
107                    self.linestyles.append(self.lsalias.get(ls))
108                else:
109                    self.linestyles.append(self.lsalias.get('line'))
110        if 0 <= linestyle < len(self.linestyles):
111            self.linestyle = linestyle
112
113    def delete(self, numbers=None):
114        """
115        Delete the 0-relative line number, default is to delete the last.
116        The remaining lines are NOT renumbered.
117        """
118
119        if numbers is None: numbers = [len(self.lines)-1]
120
121        if not hasattr(numbers, '__iter__'):
122            numbers = [numbers]
123
124        for number in numbers:
125            if 0 <= number < len(self.lines):
126                if self.lines[number] is not None:
127                    for line in self.lines[number]:
128                        line.set_linestyle('None')
129                        self.lines[number] = None
130        self.show()
131
132    def get_line(self):
133        """
134        Get the current default line attributes.
135        """
136        return self.attributes
137
138
139    def hist(self, x=None, y=None, msk=None, fmt=None, add=None):
140        """
141        Plot a histogram.  N.B. the x values refer to the start of the
142        histogram bin.
143
144        fmt is the line style as in plot().
145        """
146
147        if x is None:
148            if y is None: return
149            x = range(len(y))
150
151        if len(x) != len(y):
152            return
153        l2 = 2*len(x)
154        x2 = range(l2)
155        y2 = range(l2)
156        m2 = range(l2)
157
158        for i in range(l2):
159            x2[i] = x[i/2]
160            m2[i] = msk[i/2]
161
162        y2[0] = 0.0
163        for i in range(1,l2):
164            y2[i] = y[(i-1)/2]
165
166        self.plot(x2, y2, m2, fmt, add)
167
168
169    def hold(self, hold=True):
170        """
171        Buffer graphics until subsequently released.
172        """
173        self.buffering = hold
174
175
176    def legend(self, loc=None):
177        """
178        Add a legend to the plot.
179
180        Any other value for loc else disables the legend:
181             1: upper right
182             2: upper left
183             3: lower left
184             4: lower right
185             5: right
186             6: center left
187             7: center right
188             8: lower center
189             9: upper center
190            10: center
191
192        """
193        if isinstance(loc,int):
194            if 0 > loc > 10: loc = 0
195            self.loc = loc
196        self.show()
197
198
199    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
200        """
201        Plot the next line in the current frame using the current line
202        attributes.  The ASAPlot graphics window will be mapped and raised.
203
204        The argument list works a bit like the matlab plot() function.
205        """
206        if x is None:
207            if y is None: return
208            x = range(len(y))
209
210        elif y is None:
211            y = x
212            x = range(len(y))
213
214        ax = self.axes
215        s = ax.title.get_size()
216        tsize = s-(self.cols+self.rows)/2
217        ax.title.set_size(tsize)
218        origx = rcParams['xtick.labelsize']
219        origy = rcParams['ytick.labelsize']
220        if self.cols > 1:
221            xfsize = origx-(self.cols)/2
222            #rc('xtick',labelsize=xfsize)
223            ax.xaxis.label.set_size(xfsize)
224        if self.rows > 1:
225            yfsize = origy-(self.rows)/2
226            #rc('ytick',labelsize=yfsize)
227            ax.yaxis.label.set_size(yfsize)
228        if mask is None:
229            if fmt is None:
230                line = self.axes.plot(x, y)
231            else:
232                line = self.axes.plot(x, y, fmt)
233        else:
234            segments = []
235
236            mask = list(mask)
237            i = 0
238            while mask[i:].count(1):
239                i += mask[i:].index(1)
240                if mask[i:].count(0):
241                    j = i + mask[i:].index(0)
242                else:
243                    j = len(mask)
244
245                segments.append(x[i:j])
246                segments.append(y[i:j])
247
248                i = j
249
250            line = self.axes.plot(*segments)
251        #rc('xtick',labelsize=origx)
252        #rc('ytick',labelsize=origy)
253
254        # Add to an existing line?
255        if add is None or len(self.lines) < add < 0:
256            # Don't add.
257            self.lines.append(line)
258            i = len(self.lines) - 1
259        else:
260            if add == 0: add = len(self.lines)
261            i = add - 1
262            self.lines[i].extend(line)
263
264        # Set/reset attributes for the line.
265        gotcolour = False
266        for k, v in self.attributes.iteritems():
267            if k == 'color': gotcolour = True
268            for segment in self.lines[i]:
269                getattr(segment, "set_%s"%k)(v)
270
271        if not gotcolour and len(self.colormap):
272            for segment in self.lines[i]:
273                getattr(segment, "set_color")(self.colormap[self.color])
274                if len(self.colormap)  == 1:
275                    getattr(segment, "set_dashes")(self.linestyles[self.linestyle])
276            self.color += 1
277            if self.color >= len(self.colormap):
278                self.color = 0
279
280            if len(self.colormap) == 1:
281                self.linestyle += 1
282            if self.linestyle >= len(self.linestyles):
283                self.linestyle = 0
284
285        self.show()
286
287
288    def position(self):
289        """
290        Use the mouse to get a position from a graph.
291        """
292
293        def position_disable(event):
294            self.register('button_press', None)
295            print '%.4f, %.4f' % (event.xdata, event.ydata)
296
297        print 'Press any mouse button...'
298        self.register('button_press', position_disable)
299
300
301    def region(self):
302        """
303        Use the mouse to get a rectangular region from a plot.
304
305        The return value is [x0, y0, x1, y1] in world coordinates.
306        """
307
308        def region_start(event):
309            height = self.canvas.figure.bbox.height()
310            self.rect = {'fig': None, 'height': height,
311                         'x': event.x, 'y': height - event.y,
312                         'world': [event.xdata, event.ydata,
313                                   event.xdata, event.ydata]}
314            self.register('button_press', None)
315            self.register('motion_notify', region_draw)
316            self.register('button_release', region_disable)
317
318        def region_draw(event):
319            self.canvas._tkcanvas.delete(self.rect['fig'])
320            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
321                                self.rect['x'], self.rect['y'],
322                                event.x, self.rect['height'] - event.y)
323
324        def region_disable(event):
325            self.register('motion_notify', None)
326            self.register('button_release', None)
327
328            self.canvas._tkcanvas.delete(self.rect['fig'])
329
330            self.rect['world'][2:4] = [event.xdata, event.ydata]
331            print '(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
332                self.rect['world'][1], self.rect['world'][2],
333                self.rect['world'][3])
334
335        self.register('button_press', region_start)
336
337        # This has to be modified to block and return the result (currently
338        # printed by region_disable) when that becomes possible in matplotlib.
339
340        return [0.0, 0.0, 0.0, 0.0]
341
342
343    def register(self, type=None, func=None):
344        """
345        Register, reregister, or deregister events of type 'button_press',
346        'button_release', or 'motion_notify'.
347
348        The specified callback function should have the following signature:
349
350            def func(event)
351
352        where event is an MplEvent instance containing the following data:
353
354            name                # Event name.
355            canvas              # FigureCanvas instance generating the event.
356            x      = None       # x position - pixels from left of canvas.
357            y      = None       # y position - pixels from bottom of canvas.
358            button = None       # Button pressed: None, 1, 2, 3.
359            key    = None       # Key pressed: None, chr(range(255)), shift,
360                                  win, or control
361            inaxes = None       # Axes instance if cursor within axes.
362            xdata  = None       # x world coordinate.
363            ydata  = None       # y world coordinate.
364
365        For example:
366
367            def mouse_move(event):
368                print event.xdata, event.ydata
369
370            a = asaplot()
371            a.register('motion_notify', mouse_move)
372
373        If func is None, the event is deregistered.
374
375        Note that in TkAgg keyboard button presses don't generate an event.
376        """
377
378        if not self.events.has_key(type): return
379
380        if func is None:
381            if self.events[type] is not None:
382                # It's not clear that this does anything.
383                self.canvas.mpl_disconnect(self.events[type])
384                self.events[type] = None
385
386                # It seems to be necessary to return events to the toolbar.
387                if type == 'motion_notify':
388                    self.canvas.mpl_connect(type + '_event',
389                        self.figmgr.toolbar.mouse_move)
390                elif type == 'button_press':
391                    self.canvas.mpl_connect(type + '_event',
392                        self.figmgr.toolbar.press)
393                elif type == 'button_release':
394                    self.canvas.mpl_connect(type + '_event',
395                        self.figmgr.toolbar.release)
396
397        else:
398            self.events[type] = self.canvas.mpl_connect(type + '_event', func)
399
400
401    def release(self):
402        """
403        Release buffered graphics.
404        """
405        self.buffering = False
406        self.show()
407
408
409    def save(self, fname=None, orientation=None, dpi=None):
410        """
411        Save the plot to a file.
412
413        fname is the name of the output file.  The image format is determined
414        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
415        file name is specified 'yyyymmdd_hhmmss.png' is created in the current
416        directory.
417        """
418        if fname is None:
419            from datetime import datetime
420            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
421            fname = 'asap'+dstr+'.png'
422
423        d = ['png','.ps','eps']
424
425        from os.path import expandvars
426        fname = expandvars(fname)
427
428        if fname[-3:].lower() in d:
429            try:
430                if fname[-3:].lower() == ".ps":
431                    from matplotlib import __version__ as mv
432                    w = self.figure.figwidth.get()
433                    h = self.figure.figheight.get()
434
435                    if orientation is None:
436                        # auto oriented
437                        if w > h:
438                            orientation = 'landscape'
439                        else:
440                            orientation = 'portrait'
441                    a4w = 8.25
442                    a4h = 11.25
443                    ds = None
444                    if orientation == 'landscape':
445                        ds = min(a4h/w,a4w/h)
446                    else:
447                        ds = min(a4w/w,a4h/h)
448                    ow = ds * w
449                    oh = ds * h
450                    self.figure.set_figsize_inches((ow,oh))
451                    self.figure.savefig(fname, orientation=orientation,
452                                        papertype="a4")
453                    # reset the figure size
454                    self.figure.set_figsize_inches((w,h))
455                    print 'Written file %s' % (fname)
456                else:
457                    if dpi is None:
458                        dpi =150
459                    self.figure.savefig(fname,dpi=dpi)
460                    print 'Written file %s' % (fname)
461            except IOError, msg:
462                print 'Failed to save %s: Error msg was\n\n%s' % (fname, err)
463                return
464        else:
465            print "Invalid image type. Valid types are:"
466            print "'ps', 'eps', 'png'"
467
468
469    def set_axes(self, what=None, *args, **kwargs):
470        """
471        Set attributes for the axes by calling the relevant Axes.set_*()
472        method.  Colour translation is done as described in the doctext
473        for palette().
474        """
475
476        if what is None: return
477        if what[-6:] == 'colour': what = what[:-6] + 'color'
478
479        newargs = {}
480
481        for k, v in kwargs.iteritems():
482            k = k.lower()
483            if k == 'colour': k = 'color'
484            newargs[k] = v
485
486        getattr(self.axes, "set_%s"%what)(*args, **newargs)
487
488        self.show()
489
490
491    def set_figure(self, what=None, *args, **kwargs):
492        """
493        Set attributes for the figure by calling the relevant Figure.set_*()
494        method.  Colour translation is done as described in the doctext
495        for palette().
496        """
497
498        if what is None: return
499        if what[-6:] == 'colour': what = what[:-6] + 'color'
500        #if what[-5:] == 'color' and len(args):
501        #    args = (get_colour(args[0]),)
502
503        newargs = {}
504        for k, v in kwargs.iteritems():
505            k = k.lower()
506            if k == 'colour': k = 'color'
507            newargs[k] = v
508
509        getattr(self.figure, "set_%s"%what)(*args, **newargs)
510        self.show()
511
512
513    def set_limits(self, xlim=None, ylim=None):
514        """
515        Set x-, and y-limits for each subplot.
516
517        xlim = [xmin, xmax] as in axes.set_xlim().
518        ylim = [ymin, ymax] as in axes.set_ylim().
519        """
520        for s in self.subplots:
521            self.axes  = s['axes']
522            self.lines = s['lines']
523            oldxlim =  list(self.axes.get_xlim())
524            oldylim =  list(self.axes.get_ylim())
525            if xlim is not None:
526                for i in range(len(xlim)):
527                    if xlim[i] is not None:
528                        oldxlim[i] = xlim[i]
529            if ylim is not None:
530                for i in range(len(ylim)):
531                    if ylim[i] is not None:
532                        oldylim[i] = ylim[i]
533            self.axes.set_xlim(oldxlim)
534            self.axes.set_ylim(oldylim)
535        return
536
537
538    def set_line(self, number=None, **kwargs):
539        """
540        Set attributes for the specified line, or else the next line(s)
541        to be plotted.
542
543        number is the 0-relative number of a line that has already been
544        plotted.  If no such line exists, attributes are recorded and used
545        for the next line(s) to be plotted.
546
547        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do
548
549            import matplotlib
550            help(matplotlib.lines)
551
552        The set_* methods of class Line2D define the attribute names and
553        values.  For non-US usage, "colour" is recognized as synonymous with
554        "color".
555
556        Set the value to None to delete an attribute.
557
558        Colour translation is done as described in the doctext for palette().
559        """
560
561        redraw = False
562        for k, v in kwargs.iteritems():
563            k = k.lower()
564            if k == 'colour': k = 'color'
565
566            if 0 <= number < len(self.lines):
567                if self.lines[number] is not None:
568                    for line in self.lines[number]:
569                        getattr(line, "set_%s"%k)(v)
570                    redraw = True
571            else:
572                if v is None:
573                    del self.attributes[k]
574                else:
575                    self.attributes[k] = v
576
577        if redraw: self.show()
578
579
    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls; in
        this mode self.rows and self.cols are set to 0.

        Any other value of n discards the current panels and produces a
        rectangular array of empty panels, of which panel 0 is selected.  The
        number of these may be limited by nplots.  Note that the figure
        itself is only cleared when n < 0.

        If ganged is true (regular array only), the spacing between panels
        is reduced and tick labels/axis labels are suppressed where panels
        adjoin.
        """
        # Start from an empty figure; the title has to be added again
        # after the figure was cleared.
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if rows < 1: rows = 1

        if cols <= 0:
            # Find the near-square grid that holds 'rows' panels: cols is
            # the rounded-up square root, rows one less if that suffices.
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            # Add a single panel at grid position n to the existing ones.
            i = len(self.subplots)
            self.subplots.append({})

            self.subplots[i]['axes']  = self.figure.add_subplot(rows,
                                            cols, n+1)
            self.subplots[i]['lines'] = []

            # The first panel added becomes the current one.
            if i == 0: self.subplot(0)

            self.rows = 0
            self.cols = 0

        else:
            # Build a regular array of panels from scratch.
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols
            if ganged:
                # Close the gaps between rows and/or columns of panels.
                hsp,wsp = None,None
                if rows > 1: hsp = 0.0001
                if cols > 1: wsp = 0.0001
                self.figure.subplots_adjust(wspace=wsp,hspace=hsp)
            for i in range(nplots):
                self.subplots.append({})
                self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                                cols, i+1)
                self.subplots[i]['lines'] = []

                if ganged:
                    # Suppress tick labelling for interior subplots.
                    if i <= (rows-1)*cols - 1:
                        if i+cols < nplots:
                            # Suppress x ticks and the x-label for frames
                            # with a frame directly below them.
                            self.subplots[i]['axes'].xaxis.set_major_locator(NullLocator())
                            self.subplots[i]['axes'].xaxis.label.set_visible(False)
                    if i%cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in self.subplots[i]['axes'].yaxis.majorTicks:
                            tick.label1On = False
                        self.subplots[i]['axes'].yaxis.label.set_visible(False)
                    # Blank the first x tick label (via MyFormatter) of the
                    # panels with index above nplots-cols.
                    if (nplots-cols) < i <= nplots-1:
                        self.subplots[i]['axes'].xaxis.set_major_formatter(MyFormatter())
                # Record the grid shape; plot() uses it to scale font sizes.
                self.rows = rows
                self.cols = cols
            self.subplot(0)
661
662    def set_title(self, title=None):
663        """
664        Set the title of the plot window.  Use the previous title if title is
665        omitted.
666        """
667        if title is not None:
668            self.title = title
669
670        self.figure.text(0.5, 0.95, self.title, horizontalalignment='center')
671
672
673    def show(self):
674        """
675        Show graphics dependent on the current buffering state.
676        """
677        if not self.buffering:
678            if self.loc is not None:
679                for j in range(len(self.subplots)):
680                    lines  = []
681                    labels = []
682                    i = 0
683                    for line in self.subplots[j]['lines']:
684                        i += 1
685                        if line is not None:
686                            lines.append(line[0])
687                            lbl = line[0].get_label()
688                            if lbl == '':
689                                lbl = str(i)
690                            labels.append(lbl)
691
692                    if len(lines):
693                        self.subplots[j]['axes'].legend(tuple(lines),
694                                                        tuple(labels),
695                                                        self.loc)
696                    else:
697                        self.subplots[j]['axes'].legend((' '))
698
699
700    def subplot(self, i=None, inc=None):
701        """
702        Set the subplot to the 0-relative panel number as defined by one or
703        more invokations of set_panels().
704        """
705        l = len(self.subplots)
706        if l:
707            if i is not None:
708                self.i = i
709
710            if inc is not None:
711                self.i += inc
712
713            self.i %= l
714            self.axes  = self.subplots[self.i]['axes']
715            self.lines = self.subplots[self.i]['lines']
716
717
718    def text(self, *args, **kwargs):
719        """
720        Add text to the figure.
721        """
722        self.figure.text(*args, **kwargs)
723        self.show()
Note: See TracBrowser for help on using the repository browser.