mirror of
https://github.com/curl/curl.git
synced 2024-11-21 01:16:58 +08:00
57cc523378
These all seem reasonable to enable for this code.
454 lines
16 KiB
Python
Executable File
454 lines
16 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
#
|
|
# Project ___| | | | _ \| |
|
|
# / __| | | | |_) | |
|
|
# | (__| |_| | _ <| |___
|
|
# \___|\___/|_| \_\_____|
|
|
#
|
|
# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
|
|
#
|
|
# This software is licensed as described in the file COPYING, which
|
|
# you should have received as part of this distribution. The terms
|
|
# are also available at https://curl.se/docs/copyright.html.
|
|
#
|
|
# You may opt to use, copy, modify, merge, publish, distribute and/or sell
|
|
# copies of the Software, and permit persons to whom the Software is
|
|
# furnished to do so, under the terms of the COPYING file.
|
|
#
|
|
# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
|
|
# KIND, either express or implied.
|
|
#
|
|
# SPDX-License-Identifier: curl
|
|
#
|
|
"""Server for testing SMB."""
|
|
|
|
from __future__ import (absolute_import, division, print_function,
|
|
unicode_literals)
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
|
|
# Import our curl test data helper
|
|
from util import ClosingFileHandler, TestData
|
|
|
|
if sys.version_info.major >= 3:
|
|
import configparser
|
|
else:
|
|
import ConfigParser as configparser
|
|
|
|
# impacket needs to be installed in the Python environment
|
|
try:
|
|
import impacket # noqa: F401
|
|
except ImportError:
|
|
sys.stderr.write(
|
|
'Warning: Python package impacket is required for smb testing; '
|
|
'use pip or your package manager to install it\n')
|
|
sys.exit(1)
|
|
from impacket import smb as imp_smb
|
|
from impacket import smbserver as imp_smbserver
|
|
from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
|
|
STATUS_SUCCESS)
|
|
|
|
# Module logger; its records propagate to the root-logger handlers that
# setup_logging() configures.
log = logging.getLogger(__name__)

# Placeholder share paths. create_and_x() intercepts requests on shares
# with these paths and serves generated temporary files instead.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"

# Asking the SERVER share for VERIFIED_REQ returns VERIFIED_RSP formatted
# with this process's (adjusted) pid.
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
|
|
|
|
|
|
class ShutdownHandler(threading.Thread):
    """
    Cleanly shut down the SMB server.

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    On leaving the with statement the thread is joined, the SIGINT/SIGTERM
    handlers that were active before entry are put back, and every path in
    ``server.tmpfiles`` is deleted. A file that cannot be deleted is logged
    and skipped, so the remaining files are still removed.
    """

    def __init__(self, server):
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()
        # Signal handlers that were installed before __enter__, by signum
        self._prev_handlers = {}

    def __enter__(self):
        self.start()
        for signum in (signal.SIGINT, signal.SIGTERM):
            # signal.signal() returns the handler it replaced
            self._prev_handlers[signum] = signal.signal(signum,
                                                        self._sighandler)
        return self

    def __exit__(self, *_):
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Reinstate the handlers we replaced. signal.signal() reports None
        # for handlers not installed from Python, and None cannot be
        # passed back in, so fall back to the default action for those.
        for signum, prev in self._prev_handlers.items():
            signal.signal(signum, signal.SIG_DFL if prev is None else prev)
        self._prev_handlers.clear()
        # Delete any temporary files created by the server during its run
        log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
        for f in self.server.tmpfiles:
            try:
                os.unlink(f)
            except OSError as e:
                # One missing/undeletable file must not stop the cleanup of
                # the rest, nor mask an exception from the with body.
                log.warning("Could not delete temporary file %s: %s", f, e)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
|
|
|
|
|
|
def _write_pidfile(pidfile):
    """Write this process's pid to *pidfile* for the test harness."""
    pid = os.getpid()
    # see tests/server/util.c function write_pidfile
    if os.name == "nt":
        pid += 65536
    with open(pidfile, "w") as f:
        f.write(str(pid))


def _make_smb_config():
    """Build the impacket server config with the SERVER and TESTS shares."""
    # Here we write a mini config for the server
    smb_config = configparser.ConfigParser()
    smb_config.add_section("global")
    smb_config.set("global", "server_name", "SERVICE")
    smb_config.set("global", "server_os", "UNIX")
    smb_config.set("global", "server_domain", "WORKGROUP")
    smb_config.set("global", "log_file", "None")
    smb_config.set("global", "credentials_file", "")

    # We need a share which allows us to test that the server is running
    smb_config.add_section("SERVER")
    smb_config.set("SERVER", "comment", "server function")
    smb_config.set("SERVER", "read only", "yes")
    smb_config.set("SERVER", "share type", "0")
    smb_config.set("SERVER", "path", SERVER_MAGIC)

    # Have a share for tests. These files will be autogenerated from the
    # test input.
    smb_config.add_section("TESTS")
    smb_config.set("TESTS", "comment", "tests")
    smb_config.set("TESTS", "read only", "yes")
    smb_config.set("TESTS", "share type", "0")
    smb_config.set("TESTS", "path", TESTS_MAGIC)
    return smb_config


def smbserver(options):
    """
    Start up a TCP SMB server that serves forever.

    Uses options.srcdir (test directory; its "data" subdirectory supplies
    the test files), options.host/options.port (listen address) and
    options.pidfile (optional; receives this process's pid once the options
    have been validated).

    Returns ScriptRC.SUCCESS after the server has been shut down.
    Raises ScriptError when --srcdir is missing or not a directory.
    """
    # Validate before writing the pidfile, so a bad invocation is never
    # advertised as a started server.
    if not options.srcdir:
        raise ScriptError("--srcdir is mandatory")
    if not os.path.isdir(options.srcdir):
        raise ScriptError("--srcdir '%s' is not a directory"
                          % options.srcdir)

    if options.pidfile:
        _write_pidfile(options.pidfile)

    test_data_dir = os.path.join(options.srcdir, "data")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=_make_smb_config(),
                               test_data_directory=test_data_dir)
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # Start a thread that cleanly shuts down the server on a signal
    with ShutdownHandler(smb_server):
        # This will block until smb_server.shutdown() is called
        smb_server.serve_forever()

    return ScriptRC.SUCCESS
|
|
|
|
|
|
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests on the SERVER_MAGIC and TESTS_MAGIC shares are
    answered with temporary files generated on the fly. The paths of those
    files are collected in ``self.tmpfiles`` so the owner can delete them
    once the server has stopped.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Temporary files handed out to clients; deleted by the owner of
        # the server after it has shut down.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Only the FILE_OPEN disposition is accepted. Requests for shares other
        than the magic ones are delegated to impacket's smbComNtCreateAndX.

        Returns ([response command], None, NT status code).
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbError
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbError(STATUS_ACCESS_DENIED,
                               "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert path == TESTS_MAGIC
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Use a fid one above the highest in use on this connection.
            # dict keys cannot be indexed on Python 3, and taking the max
            # guarantees the new fid is not already in use.
            opened_files = conn_data["OpenedFiles"]
            fakefid = max(opened_files) + 1 if opened_files else 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Describe the generated file itself; the magic share path is
            # only a placeholder name, not something on disk.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The descriptor is not registered under any fid yet, so
                # nothing else would ever close it.
                os.close(fid)
                raise SmbError(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection. FileName records the
            # generated file's real path (not the magic share name) so that
            # anything resolving this fid by name reaches an existing file.
            opened_files[fakefid] = {
                "FileHandle": fid,
                "FileName": full_path,
                "DeleteOnClose": False,
            }

        except SmbError as s:
            log.debug("[SMB] SmbError hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path that the request's file name is relative to.

        If root_fid is non-zero, this is the FileName of that open file.
        Otherwise it is the "path" of the share connected as tid.
        Raises SmbError for an unknown tid or a share without a path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
                           "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbError(STATUS_ACCESS_DENIED,
                           "Connection share had no path")
        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create a temporary file answering a request on the SERVER share.

        Only VERIFIED_REQ is served. Its contents are VERIFIED_RSP formatted
        with this process's pid, adjusted as in tests/server/util.c.

        Returns (fd, path) of the temporary file, rewound to offset 0.
        Raises SmbError(STATUS_NO_SUCH_FILE) for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write all of *contents* (bytes) to fd *fid*, then rewind it."""
        # os.write may perform a short write; keep going until all is out.
        remaining = memoryview(contents)
        while remaining:
            written = os.write(fid, remaining)
            remaining = remaining[written:]
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create a temporary file holding the test data for a TESTS request.

        Returns (fd, path) of the temporary file, rewound to offset 0.
        Raises SmbError(STATUS_NO_SUCH_FILE) if the data cannot be produced.
        In that case the temporary file is closed and removed again.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
            return fid, filename

        except Exception:
            log.exception("Failed to make test file")
            # The file never reaches self.tmpfiles on this path, so nobody
            # else would close or delete it.
            for cleanup, arg in ((os.close, fid), (os.unlink, filename)):
                try:
                    cleanup(arg)
                except OSError:
                    log.warning("Cleanup of %s failed", filename)
            raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")
|
|
|
|
|
|
class SmbError(Exception):
    """Failure that should be reported to the SMB client as an NT status."""

    def __init__(self, error_code, error_message):
        # NT status code placed in the SMB response
        self.error_code = error_code
        super(SmbError, self).__init__(error_message)
|
|
|
|
|
|
class ScriptRC(object):
    """Process exit codes used by this script."""

    # Server ran and was shut down normally
    SUCCESS = 0
    # Generic failure
    FAILURE = 1
    # An unexpected exception escaped the server
    EXCEPTION = 2
|
|
|
|
|
|
class ScriptError(Exception):
    """Raised when the script is invoked incorrectly (e.g. bad --srcdir)."""
|
|
|
|
|
|
def get_options(argv=None):
    """
    Parse the command line options.

    argv is the list of arguments to parse. When None (the default),
    argparse uses the process's own arguments, sys.argv[1:].
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # A store_true flag should default to a real bool, not the int 0
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
|
|
|
|
|
|
def setup_logging(options):
    """
    Configure the root logger from the parsed command line options.

    With --logfile, a DEBUG-level file handler is attached. A DEBUG-level
    stdout handler is attached when there is no logfile or when --verbose
    is non-zero. The root logger's level is DEBUG when verbose, else WARNING.
    """
    root = logging.getLogger()
    fmt = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
    logfile = options.logfile

    if logfile:
        file_handler = ClosingFileHandler(logfile)
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.DEBUG)
        root.addHandler(file_handler)

    verbose = bool(options.verbose)
    root.setLevel(logging.DEBUG if verbose else logging.WARNING)

    # stdout is used when nothing else would receive the output, and in
    # addition to the logfile when running verbosely.
    if verbose or not logfile:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(fmt)
        console.setLevel(logging.DEBUG)
        root.addHandler(console)
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse the command line and set up logging first, so failures from
    # the server itself are logged.
    options = get_options()
    setup_logging(options)

    # Assume the worst; only a normal return from smbserver() replaces it.
    rc = ScriptRC.EXCEPTION
    try:
        rc = smbserver(options)
    except Exception:
        log.exception('Error in SMB server')

    # Remove the pidfile if one was requested and still exists.
    pidfile = options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
|