Source code for aiida_vasp.utils.bands

"""
Utils for band structures.

---------------------------
Utilities for working with band structures. Currently this is legacy and will be
rewritten or moved.
"""
# pylint: disable=import-outside-toplevel
try:
    import matplotlib
    matplotlib.use('TKAgg')
    from matplotlib import pyplot as plt
except ImportError as no_matplotlib:
    raise ImportError('Error: matplotlib must be ' + 'installed to use this functionality') from no_matplotlib


def get_bs_dims(bands_array):
    """
    Extract the band structure dimensions from a bands array.

    The shape of the array is interpreted as ``(num_kp, num_bands)`` when it is
    two-dimensional and as ``(num_spins, num_kp, num_bands)`` when it is
    three-dimensional. Any other rank yields zeros for all three values.

    :param numpy.array bands_array: an array with bands as stored in an array.bands data node
    :return: a tuple containing num_bands, num_kp, num_spins.
        If the array is only 2d, num_spins = 0
    :rtype: tuple
    """
    shape = bands_array.shape
    rank = len(shape)

    if rank == 3:
        num_spins, num_kp, num_bands = shape
        return num_bands, num_kp, num_spins

    if rank == 2:
        num_kp, num_bands = shape
        return num_bands, num_kp, 0

    # Neither 2d nor 3d: nothing can be inferred.
    return 0, 0, 0
def get_kp_labels(bands_node, kpoints_node=None):
    """
    Get Kpoint labels with their x-positions in matplotlib compatible format.

    The labels are read from ``bands_node`` first. The optional ``kpoints_node``
    is used as a fallback when ``bands_node`` either has no ``labels`` attribute
    or has one that is empty / ``None``. The caller is responsible for ensuring
    the nodes match. This should be the case if you take the kpoints from the
    input and the bands from the output of a calculation node.

    :param BandsData bands_node: The BandsData node will be searched for labels first
    :param KpointsData kpoints_node: The optional KpointsData node will be searched
        only if no labels are present on the BandsData node.
        No consistency checks are performed.
    :return: (kpx, kpl), the x-coordinates and text labels. Both lists are empty
        when no labels could be found. A label equal to ``'G'`` is replaced by
        the LaTeX string for a capital Gamma.
    :rtype: tuple(list[int], list[unicode])
    :raises AttributeError: if ``bands_node`` has no ``labels`` attribute and
        ``kpoints_node`` is not given or has no ``labels`` attribute either
    """
    try:
        kplabs = bands_node.labels
    except AttributeError:
        if kpoints_node is None:
            # No fallback available: let the original error propagate.
            raise
        kplabs = kpoints_node.labels
    else:
        # The attribute exists but carries no labels: the fallback node is the
        # only remaining source. A fallback without labels is not an error here,
        # because the bands node itself was well-formed.
        if not kplabs and kpoints_node is not None:
            kplabs = getattr(kpoints_node, 'labels', None)

    if not kplabs:
        return [], []

    kpx = [entry[0] for entry in kplabs]
    kpl = [r'$\Gamma$' if entry[1] == 'G' else entry[1] for entry in kplabs]
    return kpx, kpl
def get_efermi(calc):
    """
    Get the fermi energy from a finished calculation.

    Looks up the ``results`` entry of the calculation outputs and reads the
    ``efermi`` key of its dictionary.

    :param calc: the calculation to inspect, may be ``None``
    :return: the value stored under ``efermi``, or ``None`` when ``calc`` is
        falsy or the value is not available
    """
    if not calc:
        return None

    results = calc.get_outputs_dict().get('results')
    if not results:
        # Missing (or falsy) results output is handed back unchanged.
        return results

    return results.get_dict().get('efermi')
def get_kp_node(calc):
    """
    Get the ``kpoints`` input of a calculation.

    :param calc: the calculation to inspect, may be ``None``
    :return: the entry stored under ``kpoints`` in the calculation inputs, or
        ``None`` when ``calc`` is falsy or has no such input
    """
    if not calc:
        return None
    return calc.get_inputs_dict().get('kpoints')
def plot_bstr(bands_node, kpoints_node=None, title=None, efermi=None, use_parent_calc=False, **kwargs):
    """
    Use matplotlib to plot the bands stored in a BandsData node.

    A KpointsData node can optionally be given as a fallback for kpoint labels.
    The caller is responsible for giving a node with matching labels
    (as in they are in/out nodes of the same calculation).

    Explicitly passed ``efermi`` and ``kpoints_node`` values are always honoured.
    With ``use_parent_calc`` set, whichever of the two was not passed is looked
    up on the first input of ``bands_node``.

    :param BandsData bands_node: The BandsData node will be searched for labels first
    :param KpointsData kpoints_node: The optional KpointsData node will be searched
        only if no labels are present on the BandsData node.
        No consistency checks are performed.
    :param title: figure title, defaults to ``Band Structure (pk=<pk of bands_node>)``
    :param efermi: optional Fermi energy, drawn as a dashed horizontal line with
        its own y-tick
    :param bool use_parent_calc: if True, fill in a missing ``efermi`` and
        ``kpoints_node`` from the first input of ``bands_node``
    :param kwargs: passed on to :func:`plot_bands`
    :return: the matplotlib figure containing the plot
    """
    fig = plt.figure()
    title = title or f'Band Structure (pk={bands_node.pk})'
    _, nkp, _ = get_bs_dims(bands_node.get_bands())
    plot_bands(bands_node, **kwargs)

    if use_parent_calc:
        inputs = bands_node.get_inputs()
        parent_calc = inputs[0] if inputs else None
        # Only fill the gaps: values the caller passed in take precedence.
        if efermi is None:
            efermi = get_efermi(parent_calc)
        if kpoints_node is None:
            kpoints_node = get_kp_node(parent_calc)

    # Compare against None so that a Fermi energy of exactly 0 is still drawn.
    if efermi is not None:
        plt.hlines(efermi, plt.xlim()[0], nkp - 1, linestyles='dashed')
        # Read the ticks after drawing the line, as the line may rescale the axis.
        tick_positions = list(plt.yticks()[0])
        tick_labels = [str(tick) for tick in tick_positions]
        plt.yticks(tick_positions + [efermi], tick_labels + [r'$E_{fermi}$'])

    # Kpoint labels are decoration only: a plot without them is still useful,
    # so any failure to obtain or draw them is deliberately ignored.
    try:
        kpx, kpl = get_kp_labels(bands_node, kpoints_node)
        plt.xticks(kpx, kpl)
        plt.vlines(kpx, plt.ylim()[0], plt.ylim()[1])
    except Exception:  # pylint: disable=broad-except
        pass

    plt.ylabel('Dispersion')
    plt.suptitle(title)
    return fig
def plot_bands(bands_node, **kwargs):
    """
    Plot a bandstructure node using matplotlib.

    Bands given per spin are placed side by side so that every band of every
    spin becomes one line in the plot.

    :param BandsData bands_node: node providing the bands through ``get_bands()``
    :param kwargs: passed on to ``matplotlib.pyplot.plot``. The special key
        ``colors`` is removed first and must hold an iterable of colors; these
        are assigned to the bands one by one, starting over when exhausted.
    """
    import itertools
    import numpy as np

    bands = bands_node.get_bands()
    nbands, nkp, nspin = get_bs_dims(bands)

    if nspin > 0:
        # Flatten (nspin, nkp, nbands) into (nkp, nspin * nbands).
        allbands = np.empty((nkp, nbands * nspin))
        for spin in range(nspin):
            allbands[:, spin * nbands:(spin + 1) * nbands] = bands[spin]
        bands = allbands

    if 'colors' in kwargs:
        colors = itertools.cycle(kwargs.pop('colors'))
        # Each row of the transpose is one band along the kpoint path.
        for band in bands.T:
            # next() builtin: iterators have no .next() method on Python 3.
            plt.plot(band, color=next(colors), **kwargs)
    else:
        plt.plot(bands, **kwargs)