Source code for openmdao.util.publickey

"""
Support for generation, use, and storage of public/private key pairs.
The :func:`pk_encrypt`, :func:`pk_decrypt`, :func:`pk_sign`, and
:func:`pk_verify` functions provide a thin interface over
:class:`Crypto.PublicKey.RSA` methods for easier use and to work around some
issues found with some keys read from ssh ``id_rsa`` files.
"""

import base64
import cPickle
import getpass
import logging
import os.path
import socket
import sys
import threading

from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from Crypto.Util.number import bytes_to_long

if sys.platform == 'win32':  #pragma no cover
    try:
        import win32api
        import win32security
        import ntsecuritycon
    except ImportError:
        HAVE_PYWIN32 = False
    else:
        HAVE_PYWIN32 = True
else:
    HAVE_PYWIN32 = False

from openmdao.util.log import NullLogger

# Cache of client key pairs indexed by user.
_KEY_CACHE = {}
_KEY_CACHE_LOCK = threading.Lock()


[docs]def get_key_pair(user_host, logger=None, overwrite_cache=False, ignore_ssh=False): """ Returns RSA key containing both public and private keys for the user identified in `user_host`. This can be an expensive operation, so we avoid generating a new key pair whenever possible. If ``~/.ssh/id_rsa`` exists and is private, that key is returned. user_host: string Format ``user@host``. logger: :class:`logging.Logger` Used for debug messages. overwrite_cache: bool If True, a new key is generated and forced into the cache of existing known keys. Used for testing. ignore_ssh: bool If True, ignore any existing ssh id_rsa key file. Used for testing. .. note:: To avoid unnecessary key generation, the public/private key pair for the current user is stored in the private file ``~/.openmdao/keys``. On Windows this requires the pywin32 extension. Also, the public key is stored in ssh form in ``~/.openmdao/id_rsa.pub``. """ logger = logger or NullLogger() with _KEY_CACHE_LOCK: if overwrite_cache: key_pair = _generate(user_host, logger) _KEY_CACHE[user_host] = key_pair return key_pair # Look in previously generated keys. try: key_pair = _KEY_CACHE[user_host] except KeyError: # If key for current user (typical), check filesystem. # TODO: file lock to protect from separate processes. user, host = user_host.split('@') if user == getpass.getuser(): current_user = True key_pair = None # Try to re-use SSH key. Exceptions should *never* be exercised! if not ignore_ssh: id_rsa = \ os.path.expanduser(os.path.join('~', '.ssh', 'id_rsa')) if is_private(id_rsa): try: with open(id_rsa, 'r') as inp: key_pair = RSA.importKey(inp.read()) except Exception as exc: #pragma no cover logger.warning('ssh id_rsa import: %r', exc) else: generate = False else: #pragma no cover logger.warning('Ignoring insecure ssh id_rsa.') if key_pair is None: # Look for OpenMDAO key. key_file = \ os.path.expanduser(os.path.join('~', '.openmdao', 'keys')) if is_private(key_file): try: with open(key_file, 'rb') as inp: key_pair = cPickle.load(inp) except Exception: generate = True else: generate = False else: logger.warning('Insecure keyfile! Regenerating keys.') os.remove(key_file) generate = True # Difficult to run test as non-current user. else: #pragma no cover current_user = False generate = True if generate: key_pair = _generate(user_host, logger) if current_user: key_dir = os.path.dirname(key_file) if not os.path.exists(key_dir): os.mkdir(key_dir) # Save key pair in protected file. if sys.platform == 'win32' and not HAVE_PYWIN32: #pragma no cover logger.debug('No pywin32, not saving keyfile') else: make_private(key_dir) # Private while writing keyfile. with open(key_file, 'wb') as out: cPickle.dump(key_pair, out, cPickle.HIGHEST_PROTOCOL) try: make_private(key_file) # Hard to cause (recoverable) error here. except Exception: #pragma no cover os.remove(key_file) # Remove unsecured file. raise # Save public key in ssh form. users = {user_host: key_pair.publickey()} filename = os.path.join(key_dir, 'id_rsa.pub') write_authorized_keys(users, filename, logger) _KEY_CACHE[user_host] = key_pair return key_pair
def _generate(user_host, logger): """ Return new key. """ logger.debug('generating public key for %r...', user_host) if sys.platform == 'win32' and not HAVE_PYWIN32: #pragma no cover strength = 1024 # Much quicker to generate. else: strength = 2048 key_pair = RSA.generate(strength, get_random_bytes) logger.debug(' done') return key_pair
[docs]def pk_encrypt(data, public_key): """ Return list of chunks of `data` encrypted by `public_key`. data: string The message to be encrypted. public_key: :class:`Crypto.PublicKey.RSA` Public portion of key pair. """ # Normally we would use 8 rather than 16 here, but for some reason at least # some keys read from ssh id_rsa files don't work correctly with 8. chunk_size = public_key.size() / 16 chunks = [] while data: chunks.append(public_key.encrypt(data[:chunk_size], '')) data = data[chunk_size:] return chunks
[docs]def pk_decrypt(encrypted, private_key): """ Return `encrypted` decrypted by `private_key` as a string. encrypted: list Chunks of encrypted data returned by :func:`pk_encrypt`. private_key: :class:`Crypto.PublicKey.RSA` Private portion of key pair. """ data = '' for chunk in encrypted: data += private_key.decrypt(chunk) return data
[docs]def pk_sign(hashed, private_key): """ Return signature for `hashed` using `private_key`. hashed: string A hash value of the data to be signed. private_key: :class:`Crypto.PublicKey.RSA` Private portion of key pair. """ # Normally we would just do: # return private_key.sign(hashed, '') # But that fails for at least some keys from ssh id_rsa files. # Instead, use the 'slowmath' method: c = bytes_to_long(hashed) m = pow(c, private_key.d, private_key.n) return (m,)
[docs]def pk_verify(hashed, signature, public_key): """ Verify `hashed` based on `signature` and `public_key`. hashed: string A hash for the data that is signed. signature: tuple Value returned by :func:`pk_sign`. public_key: :class:`Crypto.PublicKey.RSA` Public portion of key pair. """ return public_key.verify(hashed, signature)
[docs]def is_private(path): """ Return True if `path` is accessible only by 'owner'. path: string Path to file or directory to check. .. note:: On Windows this requires the pywin32 extension. """ if not os.path.exists(path): return True # Nonexistent file is secure ;-) if sys.platform == 'win32': #pragma no cover if not HAVE_PYWIN32: return False # No way to know. # Find the SIDs for user and system. username = win32api.GetUserName() # Map Cygwin 'root' to 'Administrator'. Typically these are intended # to be identical, but /etc/passwd might configure them differently. if username == 'root': username = 'Administrator' user, domain, type = win32security.LookupAccountName('', username) system, domain, type = win32security.LookupAccountName('', 'System') # Find the DACL part of the Security Descriptor for the file sd = win32security.GetFileSecurity(path, win32security.DACL_SECURITY_INFORMATION) dacl = sd.GetSecurityDescriptorDacl() # Verify the DACL contains just the two entries we expect. count = dacl.GetAceCount() if count != 2: return False for i in range(count): ace = dacl.GetAce(i) if ace[2] != user and ace[2] != system: return False return True else: return (os.stat(path).st_mode & 0077) == 0
[docs]def make_private(path): """ Make `path` accessible only by 'owner'. path: string Path to file or directory to be made private. .. note:: On Windows this requires the pywin32 extension. """ if sys.platform == 'win32': #pragma no cover if not HAVE_PYWIN32: raise ImportError('No pywin32') # Find the SIDs for user and system. username = win32api.GetUserName() # Map Cygwin 'root' to 'Administrator'. Typically these are intended # to be identical, but /etc/passwd might configure them differently. if username == 'root': username = 'Administrator' user, domain, type = win32security.LookupAccountName('', username) system, domain, type = win32security.LookupAccountName('', 'System') # Find the DACL part of the Security Descriptor for the file sd = win32security.GetFileSecurity(path, win32security.DACL_SECURITY_INFORMATION) # Create a blank DACL and add the ACEs we want. dacl = win32security.ACL() dacl.AddAccessAllowedAce(win32security.ACL_REVISION, ntsecuritycon.FILE_ALL_ACCESS, user) dacl.AddAccessAllowedAce(win32security.ACL_REVISION, ntsecuritycon.FILE_ALL_ACCESS, system) # Put our new DACL into the Security Descriptor and update the file # with the updated SD. sd.SetSecurityDescriptorDacl(1, dacl, 0) win32security.SetFileSecurity(path, win32security.DACL_SECURITY_INFORMATION, sd) else: mode = 0700 if os.path.isdir(path) else 0600 os.chmod(path, mode) # Read/Write/Execute
[docs]def encode_public_key(key): """ Return base64 text representation of public key `key`. key: public key Public part of key pair. """ # Just being defensive, this should never happen. if key.has_private(): #pragma no cover key = key.publickey() return base64.b64encode(cPickle.dumps(key, cPickle.HIGHEST_PROTOCOL))
[docs]def decode_public_key(text): """ Return public key from text representation. text: string base64 encoded key data. """ return cPickle.loads(base64.b64decode(text))
[docs]def read_authorized_keys(filename=None, logger=None): """ Return dictionary of public keys, indexed by user, read from `filename`. The file must be in ssh format, and only RSA keys are processed. If the file is not private, then no keys are returned. filename: string File to read from. The default is ``~/.ssh/authorized_keys``. logger: :class:`logging.Logger` Used for log messages. """ if not filename: filename = \ os.path.expanduser(os.path.join('~', '.ssh', 'authorized_keys')) logger = logger or NullLogger() if not os.path.exists(filename): raise RuntimeError('%r does not exist' % filename) if not is_private(filename): if sys.platform != 'win32' or HAVE_PYWIN32: raise RuntimeError('%r is not private' % filename) else: #pragma no cover logger.warning('Allowed users file %r is not private', filename) errors = 0 keys = {} with open(filename, 'r') as inp: for line in inp: line = line.rstrip() sharp = line.find('#') if sharp >= 0: line = line[:sharp] if not line: continue key_type, blank, rest = line.partition(' ') if key_type != 'ssh-rsa': logger.error('unsupported key type: %r', key_type) errors += 1 continue key_data, blank, user_host = rest.partition(' ') if not key_data: logger.error('bad line (missing key data):') logger.error(line) errors += 1 continue try: user, host = user_host.split('@') except ValueError: logger.error('bad line (require user@host):') logger.error(line) errors += 1 continue logger.debug('user %r, host %r', user, host) try: ip_addr = socket.gethostbyname(host) except socket.gaierror: logger.error('unknown host %r', host) logger.error(line) errors += 1 continue data = base64.b64decode(key_data) start = 0 name_len = _longint(data, start, 4) start += 4 name = data[start:start+name_len] if name != 'ssh-rsa': logger.error('name error: %r vs. ssh-rsa', name) logger.error(line) errors += 1 continue start += name_len e_len = _longint(data, start, 4) start += 4 e = _longint(data, start, e_len) start += e_len n_len = _longint(data, start, 4) start += 4 n = _longint(data, start, n_len) start += n_len if start != len(data): logger.error('length error: %d vs. %d', start, len(data)) logger.error(line) errors += 1 continue try: pubkey = RSA.construct((n, e)) except Exception as exc: logger.error('key construct error: %r', exc) errors += 1 else: keys[user_host] = pubkey if errors: raise RuntimeError('%d errors in %r, check log for details' % (errors, filename)) return keys
def _longint(buf, start, length): """ Return long value from binary string. """ value = long(0) for i in range(length): value = (value << 8) + ord(buf[start]) start += 1 return value
[docs]def write_authorized_keys(allowed_users, filename, logger=None): """ Write `allowed_users` to `filename` in ssh format. The file will be made private if supported on this platform. allowed_users: dict Dictionary of public keys indexed by user. filename: string File to write to. logger: :class:`logging.Logger` Used for log messages. """ logger = logger or NullLogger() with open(filename, 'w') as out: for user in sorted(allowed_users.keys()): pubkey = allowed_users[user] buf = 'ssh-rsa' key_data = _longstr(len(buf), 4) key_data += buf buf = _longstr(pubkey.e) key_data += _longstr(len(buf), 4) key_data += buf buf = _longstr(pubkey.n) key_data += _longstr(len(buf), 4) key_data += buf data = base64.b64encode(key_data) out.write('ssh-rsa %s %s\n\n' % (data, user)) if sys.platform == 'win32' and not HAVE_PYWIN32: #pragma no cover logger.warning("Can't make authorized keys file %r private", filename) else: make_private(filename)
def _longstr(num, length=0): """ Return binary string representation of `num`. """ buf = chr(num & 0xff) num >>= 8 while num: buf = chr(num & 0xff) + buf num >>= 8 while len(buf) < length: buf = chr(0) + buf return buf
OpenMDAO Home