Bootstrap

Go 实现SFTP连接服务

我们将SFTP连接和处理逻辑,以及登录账户信息封装,这样可以在不同的地方重用代码,并且可以轻松地更改登录凭据。下面我将演示如何使用Go语言中的结构体来封装这些信息,并实现一个简单的SFTP服务器:

package main

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"path/filepath"
	"strings"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// 实现自定义请求处理程序
type CustomHandler struct {
	baseDir string // 基础目录,用于限制SFTP操作在特定目录下
}

func (h *CustomHandler) Fileread(request *sftp.Request) (io.ReaderAt, error) {
	path := filepath.Join(h.baseDir, request.Filepath)
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	return file, nil
}

func (h *CustomHandler) Filewrite(request *sftp.Request) (io.WriterAt, error) {
	path := filepath.Join(h.baseDir, request.Filepath)
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
	if err != nil {
		return nil, err
	}
	return file, nil
}

func (h *CustomHandler) Filecmd(request *sftp.Request) error {
	path := filepath.Join(h.baseDir, request.Filepath)
	switch request.Method {
	case "Rename":
		// 对于重命名,request.Target 会包含新的文件名
		targetPath := filepath.Join(h.baseDir, request.Target)
		// 确保目标路径不在 baseDir 之外
		if !strings.HasPrefix(targetPath, h.baseDir) {
			return errors.New("invalid target path")
		}
		return os.Rename(path, targetPath)
	case "Rmdir":
		// 删除目录
		return os.Remove(path)
	case "Mkdir":
		// 创建目录
		return os.Mkdir(path, os.ModePerm)
	case "Remove":
		return os.Remove(path)
	case "Setstat", "Link", "Symlink":
		fallthrough
	default:
		log.Printf("Filecmd request %v", request)
		return errors.New("operation not supported")
	}
	return nil
}

func (h *CustomHandler) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
	path := filepath.Join(h.baseDir, request.Filepath)
	switch request.Method {
	case "List":
		// 检索目录内容
		files, err := ioutil.ReadDir(path)
		if err != nil {
			return nil, err
		}
		// 将 os.FileInfo 列表转换为 sftp.ListerAt
		return listerAt(files), nil
	case "Stat":
		info, err := os.Stat(path)
		if err != nil {
			return nil, err
		}
		return listerAt([]os.FileInfo{info}), nil
	case "Readlink":
		target, err := os.Readlink(path)
		if err != nil {
			return nil, err
		}
		info, err := os.Lstat(target)
		if err != nil {
			return nil, err
		}
		return listerAt([]os.FileInfo{info}), nil
	}
	return nil, nil
}

// listerAt 是一个辅助类型,用于实现 sftp.ListerAt 接口
type listerAt []os.FileInfo

func (l listerAt) ListAt(list []os.FileInfo, offset int64) (int, error) {
	if offset >= int64(len(l)) {
		return 0, io.EOF
	}
	n := copy(list, l[offset:])
	return n, nil
}

type SFTPServer struct {
	HostKeyPath string
	AuthUser    string
	AuthPass    string
	Port        string
}

func NewSFTPServer(hostKeyPath, user, pass, port string) *SFTPServer {
	return &SFTPServer{
		HostKeyPath: hostKeyPath,
		AuthUser:    user,
		AuthPass:    pass,
		Port:        port,
	}
}

func (server *SFTPServer) Start() {
	config := &ssh.ServerConfig{
		PasswordCallback: server.passwordCallback,
	}

	privateBytes, err := os.ReadFile(server.HostKeyPath)
	if err != nil {
		log.Fatalf("Failed to load host key: %v", err)
	}

	private, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		log.Fatalf("Failed to parse host key: %v", err)
	}

	config.AddHostKey(private)

	listener, err := net.Listen("tcp", "0.0.0.0:"+server.Port)
	if err != nil {
		log.Fatalf("Failed to listen on port %s: %v", server.Port, err)
	}

	log.Printf("Listening on port %s...", server.Port)
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept incoming connection: %v", err)
			continue
		}

		go server.handleConn(conn, config)
	}
}

func (server *SFTPServer) handleConn(nConn net.Conn, config *ssh.ServerConfig) {
	sshConn, chans, reqs, err := ssh.NewServerConn(nConn, config)
	if err != nil {
		log.Printf("Failed to handshake: %v", err)
		return
	}
	defer sshConn.Close()

	go ssh.DiscardRequests(reqs)

	for newChannel := range chans {
		if newChannel.ChannelType() == "session" {
			go server.handleChannel(newChannel)
		} else {
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
		}
	}
}

func (server *SFTPServer) handleChannel(newChannel ssh.NewChannel) {
	channel, requests, err := newChannel.Accept()
	if err != nil {
		log.Printf("Could not accept channel (%s)", err)
		return
	}
	defer channel.Close()

	// 只在需要时创建 SFTP 服务器实例
	var sftpServer *sftp.Server

	for req := range requests {
		switch req.Type {
		case "subsystem":
			if string(req.Payload[4:]) == "sftp" {
				req.Reply(true, nil)

				baseDir, err := server.getWorkDir()
				if err != nil {
					log.Printf("Failed to create SFTP work dir: %v", err)
					return
				}
				handler := &CustomHandler{
					baseDir: baseDir,
				}
				sftpServer := sftp.NewRequestServer(channel, sftp.Handlers{
					FileGet:  handler,
					FilePut:  handler,
					FileCmd:  handler,
					FileList: handler,
				})
				if err := sftpServer.Serve(); err == io.EOF {
					log.Printf("SFTP client disconnected")
					return
				} else if err != nil {
					log.Printf("SFTP server completed with error: %v", err)
					return
				}
			} else {
				req.Reply(false, nil)
			}
		default:
			req.Reply(false, nil)
		}
	}

	if sftpServer == nil {
		log.Printf("No SFTP subsystem started")
		return
	}
}

func (server *SFTPServer) passwordCallback(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
	if c.User() == server.AuthUser && string(pass) == server.AuthPass {
		return nil, nil
	}
	return nil, fmt.Errorf("password rejected for %q", c.User())
}

func (server *SFTPServer) getWorkDir() (string, error) {
	// 获取当前工作目录
	workingDir, err := os.Getwd()
	if err != nil {
		log.Fatalf("Unable to get current working directory: %v", err)
		return "", err
	}

	// 构建操作目录的完整路径
	baseDir := filepath.Join(workingDir, "sftp_tmp")

	// 确保目录存在
	err = os.MkdirAll(baseDir, os.ModePerm)
	if err != nil {
		log.Fatalf("Unable to create tmp directory: %v", err)
		return "", err
	}
	return baseDir, nil
}

func main() {
	sftpServer := NewSFTPServer("~/.ssh/id_rsa", "root", "123456", "9527")
	sftpServer.Start()
}

效果图:
在这里插入图片描述

在这个封装中,SFTPServer结构体包含SSH服务器的配置信息,如主机密钥路径、授权用户名、密码和监听端口。NewSFTPServer函数是构造函数,用于创建SFTPServer实例。Start方法启动SFTP服务器并监听指定端口。

注意:请确保替换 hostKeyPathusernamepasswordport 这些参数为您自己的设置。~/.ssh/id_rsa 是您的服务器私钥文件的路径,您需要将其替换为实际的文件路径。root123456 是您希望用户使用的登录凭证。9527 是您希望SFTP服务器监听的端口号。

在上述封装代码中,Start 方法会启动SFTP服务器并等待连接。对于每个新连接,都会在一个新的goroutine中调用 handleConn 方法,进行SSH握手并处理SFTP会话。handleConn 方法会处理新通道,并将每个新通道传递给 handleChannel 方法,后者配置并启动SFTP服务。

passwordCallback 方法是一个回调函数,用于在SSH握手过程中验证用户凭证。如果提供的用户名和密码与结构体中定义的匹配,则会允许连接。

最后,main 函数实例化了 SFTPServer 并启动了SFTP服务。您需要确保您的系统中有SSH私钥文件,并且您有权使用指定的端口。

在实际部署中,您应该使用更安全的方法存储用户凭据,例如使用加密的方式,或者通过集成现有的用户管理系统,而不是将用户名和密码硬编码在代码中。

此外,您可能还需要添加更多的功能,例如支持基于公钥的认证、限制用户的文件系统访问权限、记录日志到文件等。这些功能可以根据需要扩展SFTPServer结构体和相关方法。

在您的main函数中调用sftpServer.Start(),就可以启动SFTP服务器。记得在正式环境中处理好错误和日志记录,确保服务的稳定性和安全性。

;