diff --git a/client/client.go b/client/client.go index 4654212..27c9078 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "path/filepath" "strings" sshgit "github.com/go-git/go-git/v5/plumbing/transport/ssh" @@ -17,6 +18,7 @@ import ( "github.com/gomicro/trust" "github.com/google/go-github/github" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" "golang.org/x/oauth2" "golang.org/x/time/rate" ) @@ -60,14 +62,17 @@ func New(cfg *config.Config) (*Client, error) { var publicKeys *sshgit.PublicKeys if cfg.Github.PrivateKey != "" && cfg.Github.Username != "" { - pem := []byte("") + pem := []byte(cfg.Github.PrivateKey) - publicKeys, err := sshgit.NewPublicKeys(cfg.Github.Username, pem, "") + publicKeys, err = sshgit.NewPublicKeys(cfg.Github.Username, pem, "") if err != nil { return nil, fmt.Errorf("public keys: %w", err) } - publicKeys.HostKeyCallback = ssh.InsecureIgnoreHostKey() + publicKeys.HostKeyCallback, err = knownHostsCallback() + if err != nil { + return nil, fmt.Errorf("known hosts: %w", err) + } } else if cfg.Github.PrivateKeyFile != "" { _, err := os.Stat(cfg.Github.PrivateKeyFile) if err != nil { @@ -79,16 +84,24 @@ func New(cfg *config.Config) (*Client, error) { return nil, fmt.Errorf("public keys file: %w", err) } - publicKeys.HostKeyCallback = ssh.InsecureIgnoreHostKey() + publicKeys.HostKeyCallback, err = knownHostsCallback() + if err != nil { + return nil, fmt.Errorf("known hosts: %w", err) + } } var pass *sshgit.Password if cfg.Github.Username != "" && cfg.Github.Token != "" { + hostKeyCallback, err := knownHostsCallback() + if err != nil { + return nil, fmt.Errorf("known hosts: %w", err) + } + pass = &sshgit.Password{ User: cfg.Github.Username, Password: cfg.Github.Token, HostKeyCallbackHelper: sshgit.HostKeyCallbackHelper{ - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: hostKeyCallback, }, } } @@ -118,6 +131,20 @@ func New(cfg *config.Config) (*Client, error) { }, nil } +func knownHostsCallback() (ssh.HostKeyCallback, error) { + usr, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("home dir: %w", err) + } + + cb, err := knownhosts.New(filepath.Join(usr, ".ssh", "known_hosts")) + if err != nil { + return nil, fmt.Errorf("parse known_hosts: %w", err) + } + + return cb, nil +} + func (c *Client) GetLogins(ctx context.Context) ([]string, error) { logins := []string{}