Skip to content

Commit

Permalink
sshdriver: ssh_prefix to list and adjust users
Browse files Browse the repository at this point in the history
Convert ssh_prefix to a list and adjust all users. This was a string
previously, which led to get() and put() being broken on master, since
the string contained spaces. Convert it to a list, adjust the callers
and add new tests to test the SSHDriver against localhost. We also add a
new pytest option to supply the username, to make this work on travis.
Remove the test class used in the sshdriver tests while touching the file.

Fixes: 1d7b97b (driver/sshdriver: Allow whitespaces in filenames Prevent filenames with whitespaces from beeing split.)

Signed-off-by: Rouven Czerwinski <[email protected]>
  • Loading branch information
Emantor committed Mar 25, 2020
1 parent 1d7b97b commit a630074
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ install:
- sudo mkdir /var/cache/labgrid && sudo chmod 1775 /var/cache/labgrid && sudo chown root:travis /var/cache/labgrid
script:
- pip install -e .
- pytest --cov-config .coveragerc --cov=labgrid --local-sshmanager
- pytest --cov-config .coveragerc --cov=labgrid --local-sshmanager --ssh-username travis
- python setup.py build_sphinx
- make -C man all
- git --no-pager diff --exit-code
Expand Down
45 changes: 23 additions & 22 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,20 @@ def __attrs_post_init__(self):
self._keepalive = None

def on_activate(self):
self.ssh_prefix = "-o LogLevel=ERROR"
self.ssh_prefix = ["-o", "LogLevel=ERROR"]
if self.keyfile:
keyfile_path = self.keyfile
if self.target.env:
keyfile_path = self.target.env.config.resolve_path(self.keyfile)
self.ssh_prefix += " -i {}".format(keyfile_path)
self.ssh_prefix += " -o PasswordAuthentication=no" if (
not self.networkservice.password) else ""
self.ssh_prefix += ["-i", keyfile_path ]
if not self.networkservice.password:
self.ssh_prefix += ["-o", "PasswordAuthentication=no"]

self.control = self._check_master()
self.ssh_prefix += " -F /dev/null"
self.ssh_prefix += " -o ControlPath={}".format(
self.control
) if self.control else ""
self.ssh_prefix += ["-F", "/dev/null"]
if self.control:
self.ssh_prefix += ["-o", "ControlPath={}".format(self.control)]

self._keepalive = None
self._start_keepalive();

Expand All @@ -60,10 +61,13 @@ def _start_own_master(self):
self.tmpdir, 'control-{}'.format(self.networkservice.address)
)
# use sshpass if we have a password
sshpass = "sshpass -e " if self.networkservice.password else ""
args = ("{}ssh -f {} -x -o ConnectTimeout=30 -o ControlPersist=300 -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o ServerAliveInterval=15 -MN -S {} -p {} {}@{}").format( # pylint: disable=line-too-long
sshpass, self.ssh_prefix, control, self.networkservice.port,
self.networkservice.username, self.networkservice.address).split(" ")
args = ["sshpass", "-e"] if self.networkservice.password else []
args += ["ssh", "-f", *self.ssh_prefix, "-x", "-o", "ConnectTimeout=30",
"-o", "ControlPersist=300", "-o",
"UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no",
"-o", "ServerAliveInterval=15", "-MN", "-S", control, "-p",
str(self.networkservice.port), "{}@{}".format(
self.networkservice.username, self.networkservice.address)]

env = os.environ.copy()
if self.networkservice.password:
Expand Down Expand Up @@ -126,13 +130,10 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): # pylin
if not self._check_keepalive():
raise ExecutionError("Keepalive no longer running")

complete_cmd = "ssh -x {prefix} -p {port} {user}@{host} {cmd}".format(
user=self.networkservice.username,
host=self.networkservice.address,
cmd=cmd,
prefix=self.ssh_prefix,
port=self.networkservice.port
).split(' ')
complete_cmd = ["ssh", "-x", *self.ssh_prefix,
"-p", str(self.networkservice.port), "{}@{}".format(
self.networkservice.username, self.networkservice.address
)] + cmd.split(" ")
self.logger.debug("Sending command: %s", complete_cmd)
if self.stderr_merge:
stderr_pipe = subprocess.STDOUT
Expand Down Expand Up @@ -166,7 +167,7 @@ def get_status(self):
def put(self, filename, remotepath=''):
transfer_cmd = [
"scp",
self.ssh_prefix,
*self.ssh_prefix,
"-P", str(self.networkservice.port),
filename,
"{user}@{host}:{remotepath}".format(
Expand All @@ -193,7 +194,7 @@ def put(self, filename, remotepath=''):
def get(self, filename, destination="."):
transfer_cmd = [
"scp",
self.ssh_prefix,
*self.ssh_prefix,
"-P", str(self.networkservice.port),
"{user}@{host}:{filename}".format(
user=self.networkservice.username,
Expand Down Expand Up @@ -236,7 +237,7 @@ def _cleanup_own_master(self):

def _start_keepalive(self):
"""Starts a keepalive connection via the own or external master."""
args = ["ssh"] + self.ssh_prefix.split() + ["cat"]
args = ["ssh", *self.ssh_prefix, "cat"]

assert self._keepalive is None
self._keepalive = subprocess.Popen(
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,17 @@ def pytest_addoption(parser):
help="Run sigrok usb tests with fx2lafw device (0925:3881)")
parser.addoption("--local-sshmanager", action="store_true",
help="Run SSHManager tests against localhost")
parser.addoption("--ssh-username", default=None,
help="SSH username to use for SSHDriver testing")

def pytest_configure(config):
# register an additional marker
config.addinivalue_line("markers",
"sigrokusb: enable fx2lafw USB tests (0925:3881)")
config.addinivalue_line("markers",
"localsshmanager: test SSHManager against Localhost")
config.addinivalue_line("markers",
"sshusername: test SSHDriver against Localhost")

def pytest_runtest_setup(item):
envmarker = item.get_closest_marker("sigrokusb")
Expand All @@ -163,3 +167,7 @@ def pytest_runtest_setup(item):
if envmarker is not None:
if item.config.getoption("--local-sshmanager") is False:
pytest.skip("SSHManager tests against localhost not enabled (enable with --local-sshmanager)")
envmarker = item.get_closest_marker("sshusername")
if envmarker is not None:
if item.config.getoption("--ssh-username") is None:
pytest.skip("SSHDriver tests against localhost not enabled (enable with --ssh-username <username>)")
127 changes: 83 additions & 44 deletions tests/test_sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,88 @@

@pytest.fixture(scope='function')
def ssh_driver_mocked_and_activated(target, mocker):
NetworkService(target, "service", "1.2.3.4", "root")
call = mocker.patch('subprocess.call')
call.return_value = 0
popen = mocker.patch('subprocess.Popen', autospec=True)
path = mocker.patch('os.path.exists')
path.return_value = True
instance_mock = mocker.MagicMock()
popen.return_value = instance_mock
instance_mock.wait = mocker.MagicMock(return_value=0)
NetworkService(target, "service", "1.2.3.4", "root")
call = mocker.patch('subprocess.call')
call.return_value = 0
popen = mocker.patch('subprocess.Popen', autospec=True)
path = mocker.patch('os.path.exists')
path.return_value = True
instance_mock = mocker.MagicMock()
popen.return_value = instance_mock
instance_mock.wait = mocker.MagicMock(return_value=0)
SSHDriver(target, "ssh")
s = target.get_driver("SSHDriver")
return s

def test_create_fail_missing_resource(target):
with pytest.raises(NoResourceFoundError):
SSHDriver(target, "ssh")
s = target.get_driver("SSHDriver")
return s

class TestSSHDriver:
def test_create_fail_missing_resource(self, target):
with pytest.raises(NoResourceFoundError):
SSHDriver(target, "ssh")

def test_create(self, target, mocker):
NetworkService(target, "service", "1.2.3.4", "root")
call = mocker.patch('subprocess.call')
call.return_value = 0
popen = mocker.patch('subprocess.Popen', autospec=True)
path = mocker.patch('os.path.exists')
path.return_value = True
instance_mock = mocker.MagicMock()
popen.return_value = instance_mock
instance_mock.wait = mocker.MagicMock(return_value=0)
s = SSHDriver(target, "ssh")
assert (isinstance(s, SSHDriver))

def test_run_check(self, ssh_driver_mocked_and_activated, mocker):
s = ssh_driver_mocked_and_activated
s._run = mocker.MagicMock(return_value=(['success'], [], 0))

def test_create(target, mocker):
NetworkService(target, "service", "1.2.3.4", "root")
call = mocker.patch('subprocess.call')
call.return_value = 0
popen = mocker.patch('subprocess.Popen', autospec=True)
path = mocker.patch('os.path.exists')
path.return_value = True
instance_mock = mocker.MagicMock()
popen.return_value = instance_mock
instance_mock.wait = mocker.MagicMock(return_value=0)
s = SSHDriver(target, "ssh")
assert isinstance(s, SSHDriver)

def test_run_check(ssh_driver_mocked_and_activated, mocker):
s = ssh_driver_mocked_and_activated
s._run = mocker.MagicMock(return_value=(['success'], [], 0))
res = s.run_check("test")
assert res == ['success']
res = s.run("test")
assert res == (['success'], [], 0)

def test_run_check_raise(ssh_driver_mocked_and_activated, mocker):
s = ssh_driver_mocked_and_activated
s._run = mocker.MagicMock(return_value=(['error'], [], 1))
with pytest.raises(ExecutionError):
res = s.run_check("test")
assert res == ['success']
res = s.run("test")
assert res == (['success'], [], 0)

def test_run_check_raise(self, ssh_driver_mocked_and_activated, mocker):
s = ssh_driver_mocked_and_activated
s._run = mocker.MagicMock(return_value=(['error'], [], 1))
with pytest.raises(ExecutionError):
res = s.run_check("test")
res = s.run("test")
assert res == (['error'], [], 1)
res = s.run("test")
assert res == (['error'], [], 1)

@pytest.fixture(scope='function')
def ssh_localhost(target, pytestconfig):
name = pytestconfig.getoption("--ssh-username")
NetworkService(target, "service", "localhost", name)
SSHDriver(target, "ssh")
s = target.get_driver("SSHDriver")
return s

@pytest.mark.sshusername
def test_local_put(ssh_localhost, tmpdir):
p = tmpdir.join("config.yaml")
p.write(
"""PUT Teststring"""
)

ssh_localhost.put(str(p), "/tmp/test_put_yaml")
assert open('/tmp/test_put_yaml', 'r').readlines() == [ "PUT Teststring" ]

@pytest.mark.sshusername
def test_local_get(ssh_localhost, tmpdir):
p = tmpdir.join("config.yaml")
p.write(
"""GET Teststring"""
)

ssh_localhost.get(str(p), "/tmp/test_get_yaml")
assert open('/tmp/test_get_yaml', 'r').readlines() == [ "GET Teststring" ]

@pytest.mark.sshusername
def test_local_run(ssh_localhost, tmpdir):

res = ssh_localhost.run("echo Hello")
assert res == (["Hello"], [], 0)

@pytest.mark.sshusername
def test_local_run_check(ssh_localhost, tmpdir):

res = ssh_localhost.run_check("echo Hello")
assert res == (["Hello"])

0 comments on commit a630074

Please sign in to comment.