Source code for pki.util

# Authors:
#     Endi S. Dewata <edewata@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
#  along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# Copyright (C) 2013 Red Hat, Inc.
# All rights reserved.
#
"""
Module containing utility functions and classes for the Dogtag python code
"""


from __future__ import absolute_import
import fileinput
import functools
import getpass
import io
import logging
import operator
import os
import re
import shutil
from shutil import Error
try:
    from shutil import WindowsError  # pylint: disable=E0611
except ImportError:
    WindowsError = None

import six
from six.moves import input   # pylint: disable=W0622,F0401
import subprocess

DEFAULT_PKI_ENV_LIST = [
    '/usr/share/pki/etc/pki.conf',
    '/etc/pki/pki.conf',
]

logger = logging.getLogger(__name__)


[docs] def replace_params(line, params=None): """ Replace all occurrences of [param] in the line with the value of the parameter. """ if not params: return line # find the first parameter in the line begin = line.find('[') # repeat while there are parameters in the line while begin >= 0: # find the end of the parameter end = line.find(']', begin + 1) # if the end not is found not found, don't do anything if end < 0: return line # get parameter name name = line[begin + 1:end] try: # get parameter value as string value = str(params[name]) # replace parameter with value, keep the rest of the line line = line[0:begin] + value + line[end + 1:] # calculate the new end position end = begin + len(value) except KeyError: # undefined parameter, skip logger.warning('Ignoring %s parameter', line[begin:end + 1]) # find the next parameter in the remainder of the line begin = line.find('[', end) return line
[docs] def makedirs( path, mode=0o777, exist_ok=False, uid=-1, gid=-1, force=False): parent = os.path.dirname(path) if not os.path.exists(parent): makedirs(parent, mode=mode, exist_ok=exist_ok, uid=uid, gid=gid, force=force) if force and os.path.exists(path): logger.warning('Directory already exists: %s', path) return logger.debug('Command: mkdir %s', path) os.makedirs(path, mode=mode, exist_ok=exist_ok) if os.geteuid() == 0: os.chown(path, uid, gid)
[docs] def copy( source, dest, uid=-1, gid=-1, dir_mode=None, file_mode=None, force=False): """ Copy a file or a folder and its contents. """ # remove trailing slashes if source[-1] == '/': source = source[:-1] if dest[-1] == '/': dest = dest[:-1] sourceparent = os.path.dirname(source) destparent = os.path.dirname(dest) if not os.path.exists(destparent): # if parent directory doesn't exist, create it first copydirs(sourceparent, destparent, uid=uid, gid=gid, mode=dir_mode, force=force) if os.path.isfile(source): # if it's a file, copy the file copyfile(source, dest, uid=uid, gid=gid, mode=file_mode, force=force) return # it's a directory tree, go through each directory in the tree for sourcepath, _, filenames in os.walk(source): relpath = sourcepath[len(source):] destpath = dest + relpath if destpath == '': destpath = '/' # copy the directory itself copydirs(sourcepath, destpath, uid=uid, gid=gid, mode=dir_mode, force=force) # copy the contents for filename in filenames: sourcefile = os.path.join(sourcepath, filename) targetfile = os.path.join(destpath, filename) copyfile(sourcefile, targetfile, uid=uid, gid=gid, mode=file_mode, force=force)
[docs] def copyfile(source, dest, params=None, uid=None, gid=None, mode=None, force=False): """ Copy a file or link while preserving its attributes. """ logger.debug('Command: cp %s %s', source, dest) # if dest already exists and not overwriting, do nothing if os.path.exists(dest) and not force: logger.warning('File already exists: %s', dest) return # if source is a link, copy the link if os.path.islink(source): target = os.readlink(source) os.symlink(target, dest) stat = os.lstat(source) if uid is None: uid = stat.st_uid if gid is None: gid = stat.st_gid if os.geteuid() == 0: os.lchown(dest, uid, gid) return # source is a file stat = os.stat(source) if not params: # if no substitution is required, copy the file shutil.copyfile(source, dest) os.utime(dest, (stat.st_atime, stat.st_mtime)) else: # otherwise, customize the file if params is None: params = {} with open(dest, 'w', encoding='utf-8') as f: for line in fileinput.FileInput(source): line = replace_params(line, params) f.write(line) if uid is None: uid = stat.st_uid if gid is None: gid = stat.st_gid if os.geteuid() == 0: os.chown(dest, uid, gid) if mode is None: mode = stat.st_mode os.chmod(dest, mode)
[docs] def copydirs(source, dest, uid=-1, gid=-1, mode=None, force=False): """ Copy a folder and its parents (without the contents) while preserving their attributes. """ destparent = os.path.dirname(dest) if not os.path.exists(destparent): sourceparent = os.path.dirname(source) copydirs(sourceparent, destparent, uid=uid, gid=gid, mode=mode, force=force) logger.debug('Command: mkdir %s', dest) if force and os.path.exists(dest): logger.warning('Directory already exists: %s', dest) return os.mkdir(dest) stat = os.stat(source) if uid == -1: uid = stat.st_uid if gid == -1: gid = stat.st_gid os.utime(dest, (stat.st_atime, stat.st_mtime)) if os.geteuid() == 0: os.chown(dest, uid, gid) if mode is None: mode = stat.st_mode os.chmod(dest, mode)
[docs] def chown(path, uid, gid): """ Change ownership of a file, link, or folder recursively. """ if os.path.islink(path): os.lchown(path, uid, gid) else: os.chown(path, uid, gid) if not os.path.isdir(path): return for item in os.listdir(path): itempath = os.path.join(path, item) chown(itempath, uid, gid)
[docs] def chmod(path, mode): """ Change permissions of a file, link, or folder recursively. """ os.chmod(path, mode) if not os.path.isdir(path): return for item in os.listdir(path): itempath = os.path.join(path, item) chmod(itempath, mode)
[docs] def remove(path, force=False): logger.debug('Command: rm -rf %s', path) if force and not os.path.exists(path): logger.warning('File not found: %s', path) return os.remove(path)
[docs] def rmtree(path, force=False): logger.debug('Command: rm -rf %s', path) if force and not os.path.exists(path): logger.warning('Directory not found: %s', path) return shutil.rmtree(path)
[docs] def customize_file(input_file, output_file, params): """ Customize a file with specified parameters. """ with open(input_file, encoding='utf-8') as infile, \ open(output_file, 'w', encoding='utf-8') as outfile: for line in infile: for src, target in params.items(): line = line.replace(src, target) outfile.write(line)
[docs] def load_properties(filename, properties): with open(filename, encoding='utf-8') as f: lines = f.read().splitlines() name = None multi_line = False for index, line in enumerate(lines): if multi_line: # append line to previous property value = properties[name] value = value + line else: # parse line for new property line = line.lstrip() if not line or line.startswith('#'): continue parts = line.split('=', 1) if len(parts) < 2: raise Exception('Missing delimiter in %s line %d' % (filename, index + 1)) name = parts[0].rstrip() value = parts[1].lstrip() # check if the value is multi-line if value.endswith('\\'): value = value[:-1] multi_line = True else: value = value.rstrip() multi_line = False # store value in properties properties[name] = value
[docs] def store_properties(filename, properties): sorted_props = sorted(properties.items(), key=operator.itemgetter(0)) with io.open(filename, 'w', encoding='utf-8') as f: for name, value in sorted_props: if value is None: # write None as empty value f.write('{0}=\n'.format(name)) elif isinstance(value, six.string_types): f.write('{0}={1}\n'.format(name, value)) elif isinstance(value, six.integer_types): f.write('{0}={1:d}\n'.format(name, value)) else: raise TypeError((name, value, type(value)))
[docs] def set_property(properties, name, value): if value is None: # no change return if value: # non-empty value updates the property properties[name] = value else: # empty value removes the property properties.pop(name, None)
[docs] def copytree(src, dst, symlinks=False, ignore=None): """ Recursively copy a directory tree using copy2(). PATCH: This code was copied from 'shutil.py' and patched to allow 'The destination directory to already exist.' If exception(s) occur, an Error is raised with a list of reasons. If the optional symlinks flag is true, symbolic links in the source tree result in symbolic links in the destination tree; if it is false, the contents of the files pointed to by symbolic links are copied. The optional ignore argument is a callable. If given, it is called with the `src` parameter, which is the directory being visited by copytree(), and `names` which is the list of `src` contents, as returned by os.listdir(): callable(src, names) -> ignored_names Since copytree() is called recursively, the callable will be called once for each directory that is copied. It returns a list of names relative to the `src` directory that should not be copied. Consider this example code rather than the ultimate tool. """ names = os.listdir(src) if ignore is not None: ignored_names = ignore(src, names) else: ignored_names = set() # PATCH: ONLY execute 'os.makedirs(dst)' if the top-level # destination directory does NOT exist! if not os.path.exists(dst): os.makedirs(dst) errors = [] for name in names: if name in ignored_names: continue srcname = os.path.join(src, name) dstname = os.path.join(dst, name) try: if symlinks and os.path.islink(srcname): linkto = os.readlink(srcname) os.symlink(linkto, dstname) elif os.path.isdir(srcname): copytree(srcname, dstname, symlinks, ignore) else: # Will raise a SpecialFileError for unsupported file types shutil.copy2(srcname, dstname) # catch the Error from the recursive copytree so that we can # continue with other files except Error as err: errors.extend(err.args[0]) except EnvironmentError as why: errors.append((srcname, dstname, str(why))) try: shutil.copystat(src, dst) except OSError as why: if WindowsError is not None and isinstance(why, WindowsError): # Copying file access times may fail on Windows pass else: errors.extend((src, dst, str(why))) if errors: raise Error(errors)
[docs] def read_environment_files(env_file_list=None): if env_file_list is None: env_file_list = DEFAULT_PKI_ENV_LIST file_command = ' && '.join( 'source {}'.format(env_file) for env_file in env_file_list) file_command += ' && env' command = [ 'bash', '-c', file_command ] env_vals = subprocess.check_output(command).decode('utf-8').split('\n') for env_val in env_vals: (key, _, value) = env_val.partition("=") if not key.strip() or key == '_': continue os.environ[key] = value
[docs] def read_text(message, options=None, default=None, delimiter=':', case_sensitive=True, password=False, required=False): """ Get an input from the user. This is used, for example, in pkispawn and pkidestroy to obtain user input. :param message: prompt to display to the user :type message: str :param options: list of possible inputs by the user. :type options: list :param default: default value of parameter being prompted. :type default: str :param delimiter: delimiter to be used at the end of the prompt. :type delimiter: str :param case_sensitive: Allow input to be case sensitive. :type case_sensitive: boolean -- True/False :param password: Input is a password. Don't show the value. :type password: boolean -- True/False :param required: Input must be non-empty. :type required: boolean -- True/False :returns: str -- value obtained from user input. """ if default is not None: if len(default) == 0: message = message + ' []' elif password: message = message + ' [********]' else: message = message + ' [' + default + ']' message = message + delimiter + ' ' if options and not case_sensitive: options = list(options) for i in range(len(options)): options[i] = options[i].lower() # normalize options while True: if password: value = getpass.getpass(message) else: value = input(message) if not value: # empty value if not required: return default if default: return default continue value = value.strip() if not value: # blank value if not required: return value continue if options: # non-empty options if case_sensitive and value in options: return value if not case_sensitive and value.lower() in options: return value continue return value
[docs] @functools.total_ordering class Version(object): def __init__(self, obj): if isinstance(obj, six.string_types): # parse <major>.<minor>.<patch>[<suffix>] match = re.match(r'^(\d+)\.(\d+)\.(\d+)', obj) if match is None: raise Exception('Unable to parse version number: %s' % obj) self.major = int(match.group(1)) self.minor = int(match.group(2)) self.patch = int(match.group(3)) elif isinstance(obj, Version): self.major = obj.major self.minor = obj.minor self.patch = obj.patch else: raise Exception('Unsupported version type: %s' % type(obj)) # release is ignored in comparisons def __eq__(self, other): return (self.major == other.major and self.minor == other.minor and self.patch == other.patch) def __ne__(self, other): return not self.__eq__(other) def __lt__(self, other): if self.major < other.major: return True if self.major == other.major and self.minor < other.minor: return True if (self.major == other.major and self.minor == other.minor and self.patch < other.patch): return True return False def __gt__(self, other): return not self.__lt__(other) and not self.__eq__(other) # not hashable __hash__ = None def __repr__(self): return '%d.%d.%d' % (self.major, self.minor, self.patch)