SSH port forwarding with Go
SSH port forwarding is a common way to reach services that should not be exposed directly to the public internet. If you're running your own database server, you most likely have a strict firewall rule in place that only allows connections from known IP addresses, and there's a good chance that the only publicly exposed port on your machine(s) is 22/ssh.
Example
To start forwarding ports, you can use the ssh command:
ssh -Ng -L 5000:localhost:5432 user@myapp.com
This starts a server on local port 5000 and forwards each connection to localhost:5432 on the myapp.com machine. Flag descriptions (per the man page):
-N
- Do not execute a remote command. This is useful for just forwarding ports.
-g
- Allows remote hosts to connect to local forwarded ports.
-L
- Specifies that the given port on the local (client) host is to be forwarded to the given host and port on the remote side. Format: [bind_address:]port:host:hostport
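If you want a Go program to accept the same -L notation, the spec can be split into the two addresses the program needs: one to listen on locally and one to dial from the SSH server's side. This is a minimal sketch under stated assumptions: forwardSpec and parseForwardSpec are hypothetical names, the default bind address of 127.0.0.1 is an assumption (ssh's real default depends on GatewayPorts and -g), and bracketed IPv6 literals and Unix socket paths are not handled.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// forwardSpec holds the two addresses described by an ssh -L argument.
type forwardSpec struct {
	LocalAddr  string // where the local listener binds
	RemoteAddr string // what gets dialed from the SSH server's side
}

// parseForwardSpec (hypothetical helper) splits "[bind_address:]port:host:hostport".
// Assumption: without a bind_address we bind to 127.0.0.1. IPv6 literals and
// Unix socket paths, which ssh also accepts, are not supported here.
func parseForwardSpec(spec string) (forwardSpec, error) {
	parts := strings.Split(spec, ":")
	bind := "127.0.0.1"
	switch len(parts) {
	case 3:
		// port:host:hostport, keep the default bind address
	case 4:
		bind, parts = parts[0], parts[1:]
	default:
		return forwardSpec{}, fmt.Errorf("invalid forward spec %q", spec)
	}
	return forwardSpec{
		LocalAddr:  net.JoinHostPort(bind, parts[0]),
		RemoteAddr: net.JoinHostPort(parts[1], parts[2]),
	}, nil
}

func main() {
	fs, err := parseForwardSpec("5000:localhost:5432")
	if err != nil {
		panic(err)
	}
	fmt.Println(fs.LocalAddr, "->", fs.RemoteAddr) // 127.0.0.1:5000 -> localhost:5432
}
```

The two resulting strings line up with the localAddr and remoteAddr values used in the Go implementation later in this post.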
The example above is meant for use with PostgreSQL. Here's how you can start a standard psql console through the tunnel:
psql postgres://user:password@127.0.0.1:5000/database
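If you build that connection URL in Go rather than typing it by hand, net/url can assemble it and take care of escaping. A small sketch; tunnelDSN is a hypothetical helper, and the user, password, and database names are the placeholders from the command above.

```go
package main

import (
	"fmt"
	"net/url"
)

// tunnelDSN (hypothetical helper) builds a postgres:// URL pointing at the
// local end of the tunnel. url.UserPassword percent-escapes characters such
// as '@' or ':' in the password that would break a concatenated string.
func tunnelDSN(user, password, localAddr, database string) string {
	u := url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(user, password),
		Host:   localAddr,
		Path:   "/" + database,
	}
	return u.String()
}

func main() {
	fmt.Println(tunnelDSN("user", "password", "127.0.0.1:5000", "database"))
	// postgres://user:password@127.0.0.1:5000/database
}
```

A password like p@ss comes out as p%40ss, so the host part of the URL is still parsed correctly.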
Implementation
The Go standard library has plenty of packages, but unfortunately it does not provide one for SSH. There's a "third-party" package, crypto/ssh, maintained by Google (docs: https://pkg.go.dev/golang.org/x/crypto/ssh):
go get golang.org/x/crypto/ssh
Implementing SSH port forwarding programmatically takes a few steps:
- Establish an SSH connection with the remote server using public key or password authentication
- Open a connection to the target ip:port through the SSH connection
- Start a local server on a port
- Accept local connections and forward data to the remote connection
Here's simplified code that does just that:
package main

import (
	"io"
	"log"
	"net"
	"os"

	"golang.org/x/crypto/ssh"
)

// Get default location of a private key
func privateKeyPath() string {
	return os.Getenv("HOME") + "/.ssh/id_rsa"
}

// Get private key for ssh authentication
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
	buff, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, err
	}
	return ssh.ParsePrivateKey(buff)
}

// Get ssh client config for our connection
// SSH config will use 2 authentication strategies: by key and by password
func makeSshConfig(user, password string) (*ssh.ClientConfig, error) {
	key, err := parsePrivateKey(privateKeyPath())
	if err != nil {
		return nil, err
	}
	config := ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(key),
			ssh.Password(password),
		},
		// Current versions of crypto/ssh refuse to connect without a
		// HostKeyCallback. Ignoring the host key is only acceptable for a
		// demo; verify it (e.g. against known_hosts) in real code.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	return &config, nil
}

// Handle local client connections and tunnel data to the remote server
// Will use io.Copy - http://golang.org/pkg/io/#Copy
func handleClient(client net.Conn, remote net.Conn) {
	defer client.Close()
	// Buffered so the goroutine that finishes second can still send and exit
	chDone := make(chan bool, 2)

	// Start remote -> local data transfer
	go func() {
		_, err := io.Copy(client, remote)
		if err != nil {
			log.Println("error while copy remote->local:", err)
		}
		chDone <- true
	}()

	// Start local -> remote data transfer
	go func() {
		_, err := io.Copy(remote, client)
		if err != nil {
			log.Println("error while copy local->remote:", err)
		}
		chDone <- true
	}()

	<-chDone
}

func main() {
	// Connection settings
	sshAddr := "remote_ip:22"
	localAddr := "127.0.0.1:5000"
	remoteAddr := "127.0.0.1:5432"

	// Build SSH client configuration
	cfg, err := makeSshConfig("user", "password")
	if err != nil {
		log.Fatalln(err)
	}

	// Establish connection with SSH server
	conn, err := ssh.Dial("tcp", sshAddr, cfg)
	if err != nil {
		log.Fatalln(err)
	}
	defer conn.Close()

	// Establish connection with remote server.
	// Note: this single connection is shared by every local client below.
	remote, err := conn.Dial("tcp", remoteAddr)
	if err != nil {
		log.Fatalln(err)
	}
	defer remote.Close()

	// Start local server to forward traffic to remote connection
	local, err := net.Listen("tcp", localAddr)
	if err != nil {
		log.Fatalln(err)
	}
	defer local.Close()

	// Handle incoming connections one at a time (handleClient blocks)
	for {
		client, err := local.Accept()
		if err != nil {
			log.Fatalln(err)
		}
		handleClient(client, remote)
	}
}
The code above does not need much explanation, except for the io.Copy(dst, src) call, which does all the magic: it copies from src to dst until either EOF is reached on src or an error occurs.
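To see those semantics in isolation, here is a standard-library-only sketch that uses net.Pipe (an in-memory connection) in place of the SSH channel. drainPipe is a hypothetical name; the point is that when the writer closes its end, io.Copy returns the number of bytes copied and a nil error, because reaching EOF is not treated as a failure.

```go
package main

import (
	"fmt"
	"io"
	"net"
	"strings"
)

// drainPipe sends msg through an in-memory net.Pipe and collects it with
// io.Copy. Closing the writing end makes the reading end return io.EOF,
// which io.Copy treats as normal completion (err == nil).
func drainPipe(msg string) (int64, string, error) {
	writer, reader := net.Pipe()
	go func() {
		// Write blocks until io.Copy on the other end has consumed the data.
		writer.Write([]byte(msg))
		writer.Close() // signals EOF to the reading side
	}()
	var sb strings.Builder
	n, err := io.Copy(&sb, reader)
	reader.Close()
	return n, sb.String(), err
}

func main() {
	n, got, err := drainPipe("hello through the tunnel")
	fmt.Println(n, got, err) // 24 hello through the tunnel <nil>
}
```

In handleClient this means the local -> remote copy ends as soon as the local client disconnects, while the remote -> local copy keeps waiting until the remote side sends more data or closes.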
The example works; however, there are a few issues with it: a) it does not handle concurrency well, and b) it is not as stable as using the plain ssh command. I have to dig into the problem a bit more to fully understand what's happening. To clarify, the concurrency issue only appears when using io.Copy over the SSH connection; using it as a local port forwarder works just fine.
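One plausible contributor to the concurrency problem (an assumption, not a confirmed diagnosis) is that the example dials the remote address once and shares that single connection among all local clients, which are also handled one at a time. A sketch of an alternative: dial a fresh remote connection per client, serve each client in its own goroutine, and close both ends when either direction finishes. To stay runnable with the standard library only, the remote side is a hypothetical dial function and the demo targets a local echo server instead of an SSH connection; with crypto/ssh, dial could wrap conn.Dial("tcp", remoteAddr).

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
)

// forward pipes data both ways between one client and its own remote conn.
// The deferred Closes unblock the other io.Copy; the buffer lets it exit.
func forward(client, remote net.Conn) {
	defer client.Close()
	defer remote.Close()
	done := make(chan struct{}, 2)
	pipe := func(dst, src net.Conn) {
		io.Copy(dst, src) // errors after a peer closes are expected here
		done <- struct{}{}
	}
	go pipe(client, remote)
	go pipe(remote, client)
	<-done
}

// serve dials a fresh remote connection per accepted client and handles each
// client in its own goroutine. dial is a hypothetical hook.
func serve(l net.Listener, dial func() (net.Conn, error)) error {
	for {
		client, err := l.Accept()
		if err != nil {
			return err
		}
		go func() {
			remote, err := dial()
			if err != nil {
				log.Println("dial remote:", err)
				client.Close()
				return
			}
			forward(client, remote)
		}()
	}
}

// echoServer stands in for the remote service so the sketch runs locally.
func echoServer() net.Listener {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	go func() {
		for {
			c, err := l.Accept()
			if err != nil {
				return
			}
			go func() { io.Copy(c, c); c.Close() }()
		}
	}()
	return l
}

// roundTrips sends one line from each of n concurrent clients through the
// forwarder and returns the echoed replies in client order.
func roundTrips(n int) []string {
	target := echoServer()
	defer target.Close()
	local, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer local.Close()
	go serve(local, func() (net.Conn, error) { return net.Dial("tcp", target.Addr().String()) })
	replies := make([]string, n)
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			c, err := net.Dial("tcp", local.Addr().String())
			if err != nil {
				return
			}
			defer c.Close()
			fmt.Fprintf(c, "client %d\n", i)
			replies[i], _ = bufio.NewReader(c).ReadString('\n')
		}(i)
	}
	wg.Wait()
	return replies
}

func main() {
	fmt.Printf("%q\n", roundTrips(3))
}
```

Each client gets its own remote connection here, so replies cannot be read by another client's copy goroutine the way they can when one remote connection is shared.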