forked from cloudflare/cloudflared
-
Notifications
You must be signed in to change notification settings - Fork 1
/
tests.py
195 lines (158 loc) · 6.02 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Cloudflared Integration tests
"""
import unittest
import subprocess
import os
import tempfile
from contextlib import contextmanager
from pexpect import pxssh
class TestSSHBase(unittest.TestCase):
    """
    SSH test base class containing constants and helper funcs.

    Connection settings are read from the environment when the class body is
    executed, so importing this module without SSH_HOSTNAME, SSH_USER,
    AUTHORIZED_KEYS_SSH_CONFIG and SHORT_LIVED_CERT_SSH_CONFIG set raises
    KeyError.
    """

    HOSTNAME = os.environ["SSH_HOSTNAME"]
    SSH_USER = os.environ["SSH_USER"]
    SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
    # ssh client config files, passed as `ssh -F <path>` / pxssh `ssh_config`.
    AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
    SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
    # `-o key=value` options shared by pxssh sessions and the ssh command
    # lines built below; host keys are not verified.
    SSH_OPTIONS = {"StrictHostKeyChecking": "no"}

    @classmethod
    def get_ssh_command(cls, pty=True):
        """
        Return the ssh command arg list for running a remote command.

        The remote command itself is not included; callers append it.

        :param pty: if true, a PTY is forced for the session (``-tt``, which
            forces allocation even without a local tty); otherwise PTY
            allocation is disabled (``-T``).
        """
        cmd = ["ssh"]
        for option, value in cls.SSH_OPTIONS.items():
            cmd += ["-o", f"{option}={value}"]
        cmd += ["-F", cls.AUTHORIZED_KEYS_SSH_CONFIG]
        # Keep the tty flag ahead of the destination so it can never be
        # interpreted as part of the remote command.
        cmd.append("-tt" if pty else "-T")
        cmd.append(cls.SSH_TARGET)
        return cmd

    @classmethod
    @contextmanager
    def ssh_session_manager(cls, *args, **kwargs):
        """
        Context manager yielding a logged-in pxssh session to HOSTNAME.

        Disables pty echo on the remote server (presumably so typed commands
        are not echoed into captured output) and ensures the session is
        logged out afterward.

        Positional args are accepted but ignored. Recognised keyword args are
        ``ssh_config`` (defaults to AUTHORIZED_KEYS_SSH_CONFIG) and
        ``ssh_tunnels`` (defaults to no tunnels); other keywords are ignored.

        :raises AssertionError: if the shell prompt does not reappear after
            disabling echo.
        """
        session = pxssh.pxssh(options=cls.SSH_OPTIONS)
        session.login(
            cls.HOSTNAME,
            username=cls.SSH_USER,
            original_prompt=r"[#@$]",
            ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
            ssh_tunnels=kwargs.get("ssh_tunnels", {}),
        )
        try:
            session.sendline("stty -echo")
            # prompt() returns False on timeout rather than raising; carrying on
            # would leave later command output misaligned.
            if not session.prompt():
                raise AssertionError("timed out waiting for prompt after 'stty -echo'")
            yield session
        finally:
            session.logout()

    @staticmethod
    def get_command_output(session, cmd):
        """
        Execute ``cmd`` on the remote ssh server and wait for the prompt.

        :returns: the text printed before the next prompt, UTF-8 decoded and
            stripped (assumes the pxssh session is in its default bytes mode).
        :raises AssertionError: if the prompt does not reappear before the
            session's timeout, instead of returning partial output.
        """
        session.sendline(cmd)
        if not session.prompt():
            raise AssertionError(f"timed out waiting for prompt after {cmd!r}")
        return session.before.decode().strip()

    def exec_command(self, cmd, shell=False):
        """
        Execute ``cmd`` locally and fail the test on a non-zero return code.

        :param cmd: argument list (or a string when ``shell`` is true).
        :param shell: run through the shell, as for ``subprocess.run``.
        :returns: ``(stdout, stderr)``, decoded and stripped.
        :raises AssertionError: if the return code is not 0; the message
            includes the captured stdout and stderr.
        """
        proc = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
        )
        # Replace undecodable bytes so a stray non-UTF-8 byte cannot mask the
        # return-code assertion (and its diagnostic output) with a decode error.
        out = proc.stdout.decode(errors="replace")
        err = proc.stderr.decode(errors="replace")
        self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
        return out.strip(), err.strip()
class TestSSHCommandExec(TestSSHBase):
    """
    Tests inline ssh command exec
    """

    # Name of file to be downloaded over SCP on remote server.
    REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]

    @classmethod
    def get_scp_base_command(cls):
        """
        Return the scp arg list shared by the SCP tests.

        It contains the SSH_OPTIONS as ``-o`` options, verbose output (``-v``),
        and the authorized-keys ssh config (``-F``). Callers append the
        source and destination.
        """
        cmd = ["scp"]
        for option, value in cls.SSH_OPTIONS.items():
            cmd += ["-o", f"{option}={value}"]
        cmd += ["-v", "-F", cls.AUTHORIZED_KEYS_SSH_CONFIG]
        return cmd

    @unittest.skip(
        "This creates files on the remote. Should be skipped until server is dockerized."
    )
    def test_verbose_scp_sink_mode(self):
        """Upload a local temporary file to the remote user's default directory."""
        with tempfile.NamedTemporaryFile() as fl:
            self.exec_command(
                self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
            )

    def test_verbose_scp_source_mode(self):
        """Download REMOTE_SCP_FILENAME into a temp dir and check it is non-empty."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.exec_command(
                self.get_scp_base_command()
                + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
            )
            # scp stores the download under the remote file's base name. Joining
            # the full value would look in the wrong place when it contains a
            # directory, and os.path.join discards tmpdirname for an absolute path.
            local_filename = os.path.join(
                tmpdirname, os.path.basename(self.REMOTE_SCP_FILENAME)
            )
            self.assertTrue(os.path.isfile(local_filename))
            self.assertGreater(os.path.getsize(local_filename), 0)

    def test_pty_command(self):
        """With a forced PTY, commands run as SSH_USER and `tty` sees a terminal."""
        base_cmd = self.get_ssh_command()

        # exec_command already strips its output.
        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.lower(), self.SSH_USER.lower())

        out, _ = self.exec_command(base_cmd + ["tty"])
        self.assertNotEqual(out, "not a tty")

    def test_non_pty_command(self):
        """Without a PTY, commands run as SSH_USER and `tty` reports no terminal."""
        base_cmd = self.get_ssh_command(pty=False)

        out, _ = self.exec_command(base_cmd + ["whoami"])
        self.assertEqual(out.lower(), self.SSH_USER.lower())

        # POSIX `tty` prints "not a tty" and exits with status 1 when stdin is
        # not a terminal, so exec_command's exit-status-0 assertion would fail
        # on the very outcome this test expects. Status 0 is tolerated too, so
        # the check does not depend on how the server forwards that status.
        proc = subprocess.run(
            base_cmd + ["tty"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        out = proc.stdout.decode(errors="replace").strip()
        err = proc.stderr.decode(errors="replace").strip()
        self.assertIn(
            proc.returncode, (0, 1), msg=f"stdout: {out} stderr: {err}"
        )
        self.assertEqual(out, "not a tty")
class TestSSHShell(TestSSHBase):
    """
    Interactive SSH shell tests, driven through a pxssh session.
    """

    # Remote path of a file that only root may read; a correctly
    # de-privileged shell must be refused when it tries to read it.
    ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]

    def test_ssh_pty(self):
        """The interactive shell runs as SSH_USER, starts in $HOME, and lacks root privileges."""
        expected_user = self.SSH_USER.lower()
        with self.ssh_session_manager() as sess:

            def run(command):
                # Send one command and return what it printed before the prompt.
                return self.get_command_output(sess, command)

            # The shell process is owned by the expected account.
            whoami = run("whoami")
            self.assertEqual(whoami.lower(), expected_user)

            # $USER names that same account.
            self.assertEqual(run("echo $USER").lower(), expected_user)

            # The working directory at startup equals $HOME...
            home_dir = run("echo $HOME")
            cwd = run("pwd")
            self.assertEqual(cwd, home_dir)
            # ...and that path contains the login name.
            self.assertIn(whoami, cwd)

            # A root-only (0700) file must not be readable from this shell.
            cat_output = run(f"cat {self.ROOT_ONLY_TEST_FILE_PATH}")
            self.assertIn("Permission denied", cat_output)

    def test_short_lived_cert_auth(self):
        """Logging in with the short-lived-certificate config yields a shell as SSH_USER."""
        cert_config = self.SHORT_LIVED_CERT_SSH_CONFIG
        with self.ssh_session_manager(ssh_config=cert_config) as sess:
            login_name = self.get_command_output(sess, "whoami")
            self.assertEqual(login_name.lower(), self.SSH_USER.lower())
# Run the suite only when executed as a script. An unguarded unittest.main()
# would also run, and call sys.exit, whenever a test runner imports this module.
if __name__ == "__main__":
    unittest.main()