easy_ssh.py

#! /usr/bin/env python
# -*- coding: utf-8 -*-


import codecs
import re
import time

import paramiko
from paramiko.py3compat import u


class Host(object):
    """SSH connection settings for one target host, optionally behind a proxy.

    Without a proxy the SSH transport goes straight to ``host``.  With a
    proxy the transport is opened to ``proxy`` instead, and the target is
    reached by running ``ssh`` inside an interactive shell on the proxy
    (see :meth:`get_session`).
    """

    # Matches a "user@host:<path><sigil>" shell prompt.  Group 1 captures the
    # text before the colon (normally "user@host"), group 2 the trailing
    # prompt character (#, > or $).
    linux_pattern = re.compile(r'(.+@.+):\s*.+([#>$])\s*')

    def __init__(self, host, username, password, proxy=None, proxy_username=None, proxy_password=None,
                 port=22, proxy_port=22):
        """Store connection settings; no network activity happens here.

        :param host: target host name or address.
        :param username: login user on the target.
        :param password: login password on the target.
        :param proxy: optional jump host; when set, :meth:`login` connects to it.
        :param proxy_username: login user on the proxy.
        :param proxy_password: login password on the proxy.
        :param port: SSH port of the target.
        :param proxy_port: SSH port of the proxy.
        """
        self.host = host
        self.username = username
        self.password = password
        self.port = port
        self.proxy = proxy
        self.proxy_username = proxy_username
        self.proxy_password = proxy_password
        self.proxy_port = proxy_port
        self.transport = None  # set by login()

    def login(self):
        """Open and authenticate the paramiko transport.

        Connects to the proxy (with the proxy credentials) when one is set,
        otherwise directly to the target, then calls ``set_keepalive(120)``.
        """
        if self.proxy:
            self.transport = paramiko.Transport(sock=(self.proxy, self.proxy_port))
            self.transport.connect(username=self.proxy_username, password=self.proxy_password)
        else:
            self.transport = paramiko.Transport(sock=(self.host, self.port))
            self.transport.connect(username=self.username, password=self.password)
        self.transport.set_keepalive(120)

    def get_session(self):
        """Open an interactive shell and return it wrapped in a :class:`Connection`.

        A pty is requested, a shell is started and the default prompt is
        derived from the login banner.  When a proxy is configured, ``ssh`` to
        the target is run in that shell, the ``password:`` prompt is answered
        with ``self.password``, and the connection's default prompt is switched
        to the prompt seen afterwards.  Finally the channel timeout is set to
        30 seconds.

        :raises ConnectionAbortedError: if :meth:`login` has not been called.
        """
        # Check explicitly instead of catching AttributeError around
        # open_session(), which would also hide AttributeErrors raised inside
        # paramiko itself.
        if self.transport is None:
            raise ConnectionAbortedError('Need to login to host first')
        ssh_channel = self.transport.open_session()
        ssh_channel.get_pty()
        ssh_channel.invoke_shell()
        conn = Connection(ssh_channel, self.get_default_prompt(ssh_channel))
        if self.proxy:
            # NOTE: linux_pattern matches any "user@host:...<sigil>" prompt, so
            # if ssh fails and the proxy's own prompt comes back it is
            # accepted here as well.
            welcome_info = conn.send(cmd='ssh -p %s %s@%s' % (self.port, self.username, self.host),
                                     prompt=self.linux_pattern, interactive={'password:': self.password})
            conn.default_prompt = self._get_prompt_pattern(welcome_info.split('\r\n')[-1])
        conn.settimeout(30.0)
        return conn

    def get_default_prompt(self, chn):
        """Read the login banner from ``chn`` and derive the shell prompt from it.

        Polls the channel 5 times (sleeping 0.1 s whenever nothing is ready),
        prints what was received, and hands the last line to
        :meth:`_get_prompt_pattern`.

        :return: a compiled regex for the prompt, or the raw last line.
        """
        chunks = []
        for _ in range(5):
            if chn.recv_ready():
                chunks.append(chn.recv(8096))
            else:
                time.sleep(0.1)
        # Decode once after joining so a multi-byte UTF-8 character split
        # across two recv() calls is not broken; undecodable bytes are
        # replaced instead of raising.
        welcome_info = b''.join(chunks).decode('utf-8', errors='replace')
        print(welcome_info)
        origin_str = welcome_info.split('\r\n')[-1]
        return self._get_prompt_pattern(origin_str)

    def _get_prompt_pattern(self, text):
        """Turn a captured prompt line into a marker for that prompt.

        If ``text`` matches :attr:`linux_pattern`, return a compiled regex
        requiring the same text before the colon and the same trailing sigil,
        with any text in between.  Otherwise return ``text`` unchanged (it is
        then used as a literal substring marker).
        """
        match_result = self.linux_pattern.match(text)
        if not match_result:
            return text
        user_host, sigil = match_result.group(1), match_result.group(2)
        # Escape the captured pieces: an unescaped '$' would become an
        # end-of-string anchor, and prompts such as "(venv) user@host" contain
        # regex metacharacters that would otherwise never match literally.
        return re.compile(re.escape(user_host) + r':\s*.+' + re.escape(sigil))

    def sftp_get(self, remotepath, localpath, callback=None):
        """Download ``remotepath`` to ``localpath`` over SFTP on this transport.

        NOTE: with a proxy the transport is connected to the proxy (see
        :meth:`login`), so the file is fetched from the proxy, not ``host``.
        """
        sftp_client = paramiko.SFTPClient.from_transport(self.transport)
        try:
            sftp_client.get(remotepath, localpath, callback)
        finally:
            # Always release the SFTP channel, even if the transfer fails.
            sftp_client.close()

    def sftp_put(self, localpath, remotepath, callback=None, confirm=True):
        """Upload ``localpath`` to ``remotepath`` over SFTP on this transport.

        NOTE: with a proxy the transport is connected to the proxy (see
        :meth:`login`), so the file is stored on the proxy, not ``host``.
        """
        sftp_client = paramiko.SFTPClient.from_transport(self.transport)
        try:
            sftp_client.put(localpath, remotepath, callback, confirm)
        finally:
            # Always release the SFTP channel, even if the transfer fails.
            sftp_client.close()


class Connection(object):
    """Interactive shell session on top of an SSH channel.

    Commands are written to the channel and output is read back until an end
    marker shows up.  A marker is either a plain string (substring match) or
    a compiled regular expression (``search`` match).
    """

    # Number of already-received characters re-scanned together with each new
    # chunk, so an end marker split across two recv() calls is still found.
    _match_overlap = 256

    def __init__(self, chn, prompt):
        """Wrap channel ``chn``, using ``prompt`` as the default end marker."""
        self.channel = chn
        self._default_prompt = prompt

    @property
    def default_prompt(self):
        """End marker used by :meth:`send` when no explicit ``prompt`` is given."""
        return self._default_prompt

    @default_prompt.setter
    def default_prompt(self, value):
        self._default_prompt = value

    def send(self, cmd, prompt=None, interactive=None, interval=120):
        """Run ``cmd`` and return everything printed until ``prompt`` appears.

        :param cmd: command line to send; a newline is appended.
        :param prompt: end marker (str or compiled regex); defaults to
            :attr:`default_prompt`.
        :param interactive: optional dict mapping an intermediate marker
            (e.g. ``'password:'``) to the reply sent whenever it appears.
        :param interval: seconds without any output before giving up.
        :return: the collected output text.
        :raises TypeError: if ``interactive`` is truthy but not a dict.
        :raises TimeoutError: if no expected marker appears in time.
        """
        if not prompt:
            prompt = self._default_prompt
        if not interactive:
            data, _ = self._send(cmd, prompt, interval)
            return data
        if not isinstance(interactive, dict):
            raise TypeError('interactive must be a dict, got %s' % type(interactive).__name__)
        # The final prompt goes first so it wins when several markers appear
        # in the same read.
        markers = [prompt] + list(interactive)
        pieces = []
        buff, end_with = self._send(cmd, markers, interval)
        pieces.append(buff)
        # _recv returns the very object from ``markers`` that matched, so an
        # identity check exactly separates the final prompt from the
        # intermediate ones.
        while end_with is not prompt:
            buff, end_with = self._send(interactive[end_with], markers, interval)
            pieces.append(buff)
        return ''.join(pieces)

    def _send(self, cmd, prompt, interval):
        """Write ``cmd`` plus a newline, read until ``prompt``, print and return it.

        :return: ``(data, marker)`` as produced by :meth:`_recv`.
        """
        self.channel.sendall(cmd + '\n')
        data, endwith = self._recv(prompt, interval)
        print(data)
        return data, endwith

    def _recv(self, end_with, interval):
        """Read from the channel until one of the end markers appears.

        :param end_with: a single marker or a list of markers.  For a list the
            first marker in list order that is found wins.
        :param interval: seconds of silence tolerated before giving up; the
            counter resets whenever data arrives.
        :return: ``(data, marker)``: all text read and the marker that matched.
        :raises TimeoutError: if ``interval`` seconds pass with no output
            before a marker is seen.
        """
        markers = end_with if isinstance(end_with, list) else [end_with]
        # Incremental decoding keeps multi-byte UTF-8 characters that straddle
        # two recv() calls intact; invalid bytes become U+FFFD instead of
        # aborting the read.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        data = ''
        no_resp_time = 0
        while no_resp_time < interval:
            if not self.channel.recv_ready():
                time.sleep(0.2)
                no_resp_time += 0.2
                continue
            no_resp_time = 0
            text = decoder.decode(self.channel.recv(1024))
            data += text
            # Scan the new text plus a little of what came before, so a
            # marker split across two chunks is still detected.
            window = data[-(len(text) + self._match_overlap):]
            for marker in markers:
                if self._is_end(window, marker):
                    return data, marker
        raise TimeoutError('no end marker received within %s seconds of silence' % interval)

    @staticmethod
    def _is_end(data, end_with):
        """Return True if marker ``end_with`` occurs in ``data``.

        A str marker is matched as a substring; any other marker is treated as
        a compiled regex and matched with ``search``.
        """
        if isinstance(end_with, str):
            return end_with in data
        return end_with.search(data) is not None

    def close(self):
        """Close the underlying channel."""
        return self.channel.close()

    @staticmethod
    def _match(pattern, string):
        """Return True if ``string`` corresponds to marker ``pattern``.

        * str ``pattern``: exact equality with ``string``.
        * regex ``pattern`` and str ``string``: ``pattern.match(string)``.
        * regex ``pattern`` and regex ``string``: identical ``.pattern`` text.

        :raises TypeError: for any other combination.
        """
        if isinstance(pattern, str):
            return pattern == string
        if hasattr(pattern, 'pattern'):
            if isinstance(string, str):
                return pattern.match(string) is not None
            if hasattr(string, 'pattern'):
                return pattern.pattern == string.pattern
        raise TypeError('cannot compare %r with %r' % (pattern, string))

    def settimeout(self, timeout):
        """Set the timeout on the underlying channel."""
        return self.channel.settimeout(timeout)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容