|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"crypto/tls"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httputil"
|
|
|
|
"net/url"
|
|
|
|
"regexp"
|
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
type HandlerWrapper struct {
|
|
|
|
MyConfig *Cfg
|
|
|
|
tlsConfig *TlsConfig
|
|
|
|
wrapped http.Handler
|
|
|
|
pk *PrivateKey
|
|
|
|
pkPem []byte
|
|
|
|
issuingCert *Certificate
|
|
|
|
issuingCertPem []byte
|
|
|
|
serverTLSConfig *tls.Config
|
|
|
|
dynamicCerts *Cache
|
|
|
|
certMutex sync.Mutex
|
|
|
|
https bool
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) GenerateCertForClient() (err error) {
|
|
|
|
if hw.tlsConfig.Organization == "" {
|
|
|
|
hw.tlsConfig.Organization = "gomitmproxy" + Version
|
|
|
|
}
|
|
|
|
if hw.tlsConfig.CommonName == "" {
|
|
|
|
hw.tlsConfig.CommonName = "gomitmproxy"
|
|
|
|
}
|
|
|
|
if hw.pk, err = LoadPKFromFile(hw.tlsConfig.PrivateKeyFile); err != nil {
|
|
|
|
hw.pk, err = GeneratePK(2048)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to generate private key: %s", err)
|
|
|
|
}
|
|
|
|
hw.pk.WriteToFile(hw.tlsConfig.PrivateKeyFile)
|
|
|
|
}
|
|
|
|
hw.pkPem = hw.pk.PEMEncoded()
|
|
|
|
hw.issuingCert, err = LoadCertificateFromFile(hw.tlsConfig.CertFile)
|
|
|
|
if err != nil || hw.issuingCert.ExpiresBefore(time.Now().AddDate(0, ONE_MONTH, 0)) {
|
|
|
|
hw.issuingCert, err = hw.pk.TLSCertificateFor(
|
|
|
|
hw.tlsConfig.Organization,
|
|
|
|
hw.tlsConfig.CommonName,
|
|
|
|
time.Now().AddDate(ONE_YEAR, 0, 0),
|
|
|
|
true,
|
|
|
|
nil)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("Unable to generate self-signed issuing certificate: %s", err)
|
|
|
|
}
|
|
|
|
hw.issuingCert.WriteToFile(hw.tlsConfig.CertFile)
|
|
|
|
}
|
|
|
|
hw.issuingCertPem = hw.issuingCert.PEMEncoded()
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) FakeCertForName(name string) (cert *tls.Certificate, err error) {
|
|
|
|
kpCandidateIf, found := hw.dynamicCerts.Get(name)
|
|
|
|
if found {
|
|
|
|
return kpCandidateIf.(*tls.Certificate), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
hw.certMutex.Lock()
|
|
|
|
defer hw.certMutex.Unlock()
|
|
|
|
kpCandidateIf, found = hw.dynamicCerts.Get(name)
|
|
|
|
if found {
|
|
|
|
return kpCandidateIf.(*tls.Certificate), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
//create certificate
|
|
|
|
certTTL := TWO_WEEKS
|
|
|
|
generatedCert, err := hw.pk.TLSCertificateFor(
|
|
|
|
hw.tlsConfig.Organization,
|
|
|
|
name,
|
|
|
|
time.Now().Add(certTTL),
|
|
|
|
false,
|
|
|
|
hw.issuingCert)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("Unable to issue certificate: %s", err)
|
|
|
|
}
|
|
|
|
keyPair, err := tls.X509KeyPair(generatedCert.PEMEncoded(), hw.pkPem)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("Unable to parse keypair for tls: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
cacheTTL := certTTL - ONE_DAY
|
|
|
|
hw.dynamicCerts.Set(name, &keyPair, cacheTTL)
|
|
|
|
return &keyPair, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) DumpHTTPAndHTTPs(resp http.ResponseWriter, req *http.Request) {
|
|
|
|
req.Header.Del("Proxy-Connection")
|
|
|
|
req.Header.Set("Connection", "Keep-Alive")
|
|
|
|
|
|
|
|
// handle connection
|
|
|
|
connIn, _, err := resp.(http.Hijacker).Hijack()
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("hijack error:", err)
|
|
|
|
}
|
|
|
|
defer connIn.Close()
|
|
|
|
|
|
|
|
var respOut *http.Response
|
|
|
|
host := req.Host
|
|
|
|
|
|
|
|
matched, _ := regexp.MatchString(":[0-9]+$", host)
|
|
|
|
|
|
|
|
if !hw.https {
|
|
|
|
if !matched {
|
|
|
|
host += ":80"
|
|
|
|
}
|
|
|
|
|
|
|
|
connOut, err := net.DialTimeout("tcp", host, time.Second*30)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("dial to", host, "error:", err)
|
|
|
|
}
|
|
|
|
respOut, err = http.ReadResponse(bufio.NewReader(connOut), req)
|
|
|
|
if err != nil && err != io.EOF {
|
|
|
|
logger.Println("read response error:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
if !matched {
|
|
|
|
host += ":443"
|
|
|
|
}
|
|
|
|
|
|
|
|
connOut, err := tls.Dial("tcp", host, hw.tlsConfig.ServerTLSConfig)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
logger.Panicln("tls dial to", host, "error:", err)
|
|
|
|
}
|
|
|
|
if err = req.Write(connOut); err != nil {
|
|
|
|
logger.Println("send to server error", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
respOut, err = http.ReadResponse(bufio.NewReader(connOut), req)
|
|
|
|
if err != nil && err != io.EOF {
|
|
|
|
logger.Println("read response error:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if respOut == nil {
|
|
|
|
log.Println("respOut is nil")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
respDump, err := httputil.DumpResponse(respOut, true)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("respDump error:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = connIn.Write(respDump)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("connIn write error:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if *hw.MyConfig.Monitor {
|
|
|
|
go httpDump(req, respOut)
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
|
|
|
|
|
|
|
raddr := *hw.MyConfig.Raddr
|
|
|
|
if len(raddr) != 0 {
|
|
|
|
hw.Forward(resp, req, raddr)
|
|
|
|
} else {
|
|
|
|
if req.Method == "CONNECT" {
|
|
|
|
hw.https = true
|
|
|
|
hw.InterceptHTTPs(resp, req)
|
|
|
|
} else {
|
|
|
|
hw.DumpHTTPAndHTTPs(resp, req)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) InterceptHTTPs(resp http.ResponseWriter, req *http.Request) {
|
|
|
|
addr := req.Host
|
|
|
|
host := strings.Split(addr, ":")[0]
|
|
|
|
|
|
|
|
cert, err := hw.FakeCertForName(host)
|
|
|
|
if err != nil {
|
|
|
|
msg := fmt.Sprintf("Could not get mitm cert for name: %s\nerror: %s", host, err)
|
|
|
|
respBadGateway(resp, msg)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// handle connection
|
|
|
|
connIn, _, err := resp.(http.Hijacker).Hijack()
|
|
|
|
if err != nil {
|
|
|
|
msg := fmt.Sprintf("Unable to access underlying connection from client: %s", err)
|
|
|
|
respBadGateway(resp, msg)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
tlsConfig := copyTlsConfig(hw.tlsConfig.ServerTLSConfig)
|
|
|
|
tlsConfig.Certificates = []tls.Certificate{*cert}
|
|
|
|
tlsConnIn := tls.Server(connIn, tlsConfig)
|
|
|
|
listener := &mitmListener{tlsConnIn}
|
|
|
|
handler := http.HandlerFunc(func(resp2 http.ResponseWriter, req2 *http.Request) {
|
|
|
|
req2.URL.Scheme = "https"
|
|
|
|
req2.URL.Host = req2.Host
|
|
|
|
hw.DumpHTTPAndHTTPs(resp2, req2)
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
err = http.Serve(listener, handler)
|
|
|
|
if err != nil && err != io.EOF {
|
|
|
|
logger.Printf("Error serving mitm'ed connection: %s", err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
connIn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (hw *HandlerWrapper) Forward(resp http.ResponseWriter, req *http.Request, raddr string) {
|
|
|
|
connIn, _, err := resp.(http.Hijacker).Hijack()
|
|
|
|
connOut, err := net.Dial("tcp", raddr)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("dial tcp error", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
err = connectProxyServer(connOut, raddr)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("connectProxyServer error:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if req.Method == "CONNECT" {
|
|
|
|
b := []byte("HTTP/1.1 200 Connection Established\r\n" +
|
|
|
|
"Proxy-Agent: gomitmproxy/" + Version + "\r\n\r\n")
|
|
|
|
_, err := connIn.Write(b)
|
|
|
|
if err != nil {
|
|
|
|
logger.Println("Write Connect err:", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
req.Header.Del("Proxy-Connection")
|
|
|
|
req.Header.Set("Connection", "Keep-Alive")
|
|
|
|
if err = req.Write(connOut); err != nil {
|
|
|
|
logger.Println("send to server err", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
err = Transport(connIn, connOut)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("trans error ", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func InitConfig(conf *Cfg, tlsConfig *TlsConfig) (*HandlerWrapper, error) {
|
|
|
|
hw := &HandlerWrapper{
|
|
|
|
MyConfig: conf,
|
|
|
|
tlsConfig: tlsConfig,
|
|
|
|
dynamicCerts: NewCache(),
|
|
|
|
}
|
|
|
|
err := hw.GenerateCertForClient()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return hw, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// copyTlsConfig returns a new *tls.Config that the caller may modify freely.
//
// When template is nil an empty configuration is returned. Otherwise the
// result is a shallow clone made with (*tls.Config).Clone rather than a
// struct assignment: tls.Config carries internal synchronization state that
// must not be copied by value, and Clone is documented as safe while the
// template is in use by other connections.
func copyTlsConfig(template *tls.Config) *tls.Config {
	if template == nil {
		return &tls.Config{}
	}
	return template.Clone()
}
|
|
|
|
|
|
|
|
// respBadGateway logs msg and sends it to the client as the body of a
// 502 Bad Gateway response.
func respBadGateway(resp http.ResponseWriter, msg string) {
	log.Println(msg)

	resp.WriteHeader(http.StatusBadGateway)
	body := []byte(msg)
	resp.Write(body)
}
|
|
|
|
|
|
|
|
//两个io口的连接
|
|
|
|
func Transport(conn1, conn2 net.Conn) (err error) {
|
|
|
|
rChan := make(chan error, 1)
|
|
|
|
wChan := make(chan error, 1)
|
|
|
|
|
|
|
|
go MyCopy(conn1, conn2, wChan)
|
|
|
|
go MyCopy(conn2, conn1, rChan)
|
|
|
|
|
|
|
|
select {
|
|
|
|
case err = <-wChan:
|
|
|
|
case err = <-rChan:
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// MyCopy copies everything from src to dst with io.Copy and then sends the
// resulting error (nil on success) on ch. The send blocks until ch can
// accept the value.
func MyCopy(src io.Reader, dst io.Writer, ch chan<- error) {
	var copyErr error
	_, copyErr = io.Copy(dst, src)
	ch <- copyErr
}
|
|
|
|
|
|
|
|
// connectProxyServer sends an HTTP/1.1 CONNECT request for addr over conn
// and reads the reply.
//
// It returns nil when the reply status is 200 OK. A write or read failure is
// returned as is; any other status is returned as an error whose text is the
// response status line (for example "403 Forbidden"). The response body is
// left untouched.
func connectProxyServer(conn net.Conn, addr string) error {
	headers := http.Header{}
	headers.Set("Proxy-Connection", "keep-alive")

	connectReq := &http.Request{
		Method:     http.MethodConnect,
		URL:        &url.URL{Host: addr},
		Host:       addr,
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     headers,
	}

	if err := connectReq.Write(conn); err != nil {
		return err
	}

	reply, err := http.ReadResponse(bufio.NewReader(conn), connectReq)
	switch {
	case err != nil:
		return err
	case reply.StatusCode != http.StatusOK:
		return errors.New(reply.Status)
	}
	return nil
}
|