diff --git a/dump.go b/dump.go index e437a9c..ca261a1 100644 --- a/dump.go +++ b/dump.go @@ -32,10 +32,6 @@ func httpDump(req *http.Request, resp *http.Response) { for headerName, headerContext := range req.Header { fmt.Printf("%s: %s\n", Blue(headerName), headerContext) } - fmt.Println(Green("Response:")) - for headerName, headerContext := range resp.Header { - fmt.Printf("%s: %s\n", Blue(headerName), headerContext) - } if req.Method == "POST" { fmt.Println(Green("URLEncoded form")) err := req.ParseForm() @@ -51,6 +47,11 @@ func httpDump(req *http.Request, resp *http.Response) { } } + fmt.Println(Green("Response:")) + for headerName, headerContext := range resp.Header { + fmt.Printf("%s: %s\n", Blue(headerName), headerContext) + } + respBody, err := ioutil.ReadAll(resp.Body) if err != nil { logger.Println("func httpDump read resp body err:", err) diff --git a/mitm.go b/mitm.go index 6a1a8f2..9f8a54e 100644 --- a/mitm.go +++ b/mitm.go @@ -1,81 +1,342 @@ package main import ( - "flag" + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" "log" + "net" "net/http" - "os" + "net/http/httputil" + "net/url" + "regexp" + "strings" "sync" "time" ) -const ( - Version = "1.1" -) +type HandlerWrapper struct { + MyConfig *Cfg + tlsConfig *TlsConfig + wrapped http.Handler + pk *PrivateKey + pkPem []byte + issuingCert *Certificate + issuingCertPem []byte + serverTLSConfig *tls.Config + dynamicCerts *Cache + certMutex sync.Mutex + https bool +} -var ( - wg sync.WaitGroup -) +func (hw *HandlerWrapper) GenerateCertForClient() (err error) { + if hw.tlsConfig.Organization == "" { + hw.tlsConfig.Organization = "gomitmproxy" + Version + } + if hw.tlsConfig.CommonName == "" { + hw.tlsConfig.CommonName = "gomitmproxy" + } + if hw.pk, err = LoadPKFromFile(hw.tlsConfig.PrivateKeyFile); err != nil { + hw.pk, err = GeneratePK(2048) + if err != nil { + return fmt.Errorf("Unable to generate private key: %s", err) + } + hw.pk.WriteToFile(hw.tlsConfig.PrivateKeyFile) + } + hw.pkPem = hw.pk.PEMEncoded() + hw.issuingCert, err = LoadCertificateFromFile(hw.tlsConfig.CertFile) + if err != nil || hw.issuingCert.ExpiresBefore(time.Now().AddDate(0, ONE_MONTH, 0)) { + hw.issuingCert, err = hw.pk.TLSCertificateFor( + hw.tlsConfig.Organization, + hw.tlsConfig.CommonName, + time.Now().AddDate(ONE_YEAR, 0, 0), + true, + nil) + if err != nil { + return fmt.Errorf("Unable to generate self-signed issuing certificate: %s", err) + } + hw.issuingCert.WriteToFile(hw.tlsConfig.CertFile) + } + hw.issuingCertPem = hw.issuingCert.PEMEncoded() + return +} -var logFile *os.File -var logger *log.Logger +func (hw *HandlerWrapper) FakeCertForName(name string) (cert *tls.Certificate, err error) { + kpCandidateIf, found := hw.dynamicCerts.Get(name) + if found { + return kpCandidateIf.(*tls.Certificate), nil + } -func main() { - var conf Cfg + hw.certMutex.Lock() + defer hw.certMutex.Unlock() + kpCandidateIf, found = hw.dynamicCerts.Get(name) + if found { + return kpCandidateIf.(*tls.Certificate), nil + } - conf.Port = flag.String("port", "8080", "Listen port") - conf.Raddr = flag.String("raddr", "", "Remote addr") - conf.Log = flag.String("log", "./error.log", "log file path") - conf.Monitor = flag.Bool("m", false, "monitor mode") - conf.Tls = flag.Bool("tls", false, "tls connect") - flag.Parse() + //create certificate + certTTL := TWO_WEEKS + generatedCert, err := hw.pk.TLSCertificateFor( + hw.tlsConfig.Organization, + name, + time.Now().Add(certTTL), + false, + hw.issuingCert) + if err != nil { + return nil, fmt.Errorf("Unable to issue certificate: %s", err) + } + keyPair, err := tls.X509KeyPair(generatedCert.PEMEncoded(), hw.pkPem) + if err != nil { + return nil, fmt.Errorf("Unable to parse keypair for tls: %s", err) + } - var err error - logFile, err = os.Create(*conf.Log) + cacheTTL := certTTL - ONE_DAY + hw.dynamicCerts.Set(name, &keyPair, cacheTTL) + return &keyPair, nil +} + +func (hw *HandlerWrapper) DumpHTTPAndHTTPs(resp http.ResponseWriter, req *http.Request) { + req.Header.Del("Proxy-Connection") + req.Header.Set("Connection", "Keep-Alive") + + // handle connection + connIn, _, err := resp.(http.Hijacker).Hijack() if err != nil { - log.Fatalln("fail to create log file!") + logger.Println("hijack error:", err) } + defer connIn.Close() - logger = log.New(logFile, "[gomitmproxy]", log.LstdFlags|log.Llongfile) + var respOut *http.Response + host := req.Host - wg.Add(1) - gomitmproxy(&conf) - wg.Wait() -} + matched, _ := regexp.MatchString(":[0-9]+$", host) + + if !hw.https { + if !matched { + host += ":80" + } + + connOut, err := net.DialTimeout("tcp", host, time.Second*30) + if err != nil { + logger.Println("dial to", host, "error:", err) + return + } + + if err = req.Write(connOut); err != nil { + logger.Println("send to server error", err) + return + } + + respOut, err = http.ReadResponse(bufio.NewReader(connOut), req) + if err != nil && err != io.EOF { + logger.Println("read response error:", err) + } + + } else { + if !matched { + host += ":443" + } + + connOut, err := tls.Dial("tcp", host, hw.tlsConfig.ServerTLSConfig) + + if err != nil { + logger.Panicln("tls dial to", host, "error:", err) + return + } + if err = req.Write(connOut); err != nil { + logger.Println("send to server error", err) + return + } -func gomitmproxy(conf *Cfg) { - tlsConfig := NewTlsConfig("gomitmproxy-ca-pk.pem", "gomitmproxy-ca-cert.pem", "", "") + respOut, err = http.ReadResponse(bufio.NewReader(connOut), req) + if err != nil && err != io.EOF { + logger.Println("read response error:", err) + } + + } - handler, err := InitConfig(conf, tlsConfig) + if respOut == nil { + log.Println("respOut is nil") + return + } + + respDump, err := httputil.DumpResponse(respOut, true) if err != nil { - logger.Fatalf("InitConfig error: %s", err) + logger.Println("respDump error:", err) } - server := &http.Server{ - Addr: ":" + *conf.Port, - Handler: handler, - ReadTimeout: 1 * time.Hour, - WriteTimeout: 1 * time.Hour, + _, err = connIn.Write(respDump) + if err != nil { + logger.Println("connIn write error:", err) } - go func() { - log.Printf("proxy listening port:%s", *conf.Port) + if *hw.MyConfig.Monitor { + go httpDump(req, respOut) + } + +} - if *conf.Tls { - log.Println("ListenAndServeTLS") - err = server.ListenAndServeTLS("gomitmproxy-ca-cert.pem", "gomitmproxy-ca-pk.pem") +func (hw *HandlerWrapper) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + + raddr := *hw.MyConfig.Raddr + if len(raddr) != 0 { + hw.Forward(resp, req, raddr) + } else { + if req.Method == "CONNECT" { + hw.https = true + hw.InterceptHTTPs(resp, req) } else { - log.Println("ListenAndServe") - err = server.ListenAndServe() + hw.https = false + hw.DumpHTTPAndHTTPs(resp, req) } + } +} + +func (hw *HandlerWrapper) InterceptHTTPs(resp http.ResponseWriter, req *http.Request) { + addr := req.Host + host := strings.Split(addr, ":")[0] + + cert, err := hw.FakeCertForName(host) + if err != nil { + msg := fmt.Sprintf("Could not get mitm cert for name: %s\nerror: %s", host, err) + respBadGateway(resp, msg) + return + } + + // handle connection + connIn, _, err := resp.(http.Hijacker).Hijack() + if err != nil { + msg := fmt.Sprintf("Unable to access underlying connection from client: %s", err) + respBadGateway(resp, msg) + return + } + tlsConfig := copyTlsConfig(hw.tlsConfig.ServerTLSConfig) + tlsConfig.Certificates = []tls.Certificate{*cert} + tlsConnIn := tls.Server(connIn, tlsConfig) + listener := &mitmListener{tlsConnIn} + handler := http.HandlerFunc(func(resp2 http.ResponseWriter, req2 *http.Request) { + req2.URL.Scheme = "https" + req2.URL.Host = req2.Host + hw.DumpHTTPAndHTTPs(resp2, req2) + + }) + + go func() { + err = http.Serve(listener, handler) + if err != nil && err != io.EOF { + logger.Printf("Error serving mitm'ed connection: %s", err) + } + }() + + connIn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) +} + +func (hw *HandlerWrapper) Forward(resp http.ResponseWriter, req *http.Request, raddr string) { + connIn, _, err := resp.(http.Hijacker).Hijack() + connOut, err := net.Dial("tcp", raddr) + if err != nil { + logger.Println("dial tcp error", err) + } + + err = connectProxyServer(connOut, raddr) + if err != nil { + logger.Println("connectProxyServer error:", err) + } + + if req.Method == "CONNECT" { + b := []byte("HTTP/1.1 200 Connection Established\r\n" + + "Proxy-Agent: gomitmproxy/" + Version + "\r\n\r\n") + _, err := connIn.Write(b) if err != nil { - logger.Fatalf("Unable to start HTTP proxy: %s", err) + logger.Println("Write Connect err:", err) + return + } + } else { + req.Header.Del("Proxy-Connection") + req.Header.Set("Connection", "Keep-Alive") + if err = req.Write(connOut); err != nil { + logger.Println("send to server err", err) + return } + } + err = Transport(connIn, connOut) + if err != nil { + log.Println("trans error ", err) + } +} - wg.Done() +func InitConfig(conf *Cfg, tlsConfig *TlsConfig) (*HandlerWrapper, error) { + hw := &HandlerWrapper{ + MyConfig: conf, + tlsConfig: tlsConfig, + dynamicCerts: NewCache(), + } + err := hw.GenerateCertForClient() + if err != nil { + return nil, err + } + return hw, nil +} - log.Printf("gomitmproxy stop!!!!") - }() +func copyTlsConfig(template *tls.Config) *tls.Config { + tlsConfig := &tls.Config{} + if template != nil { + *tlsConfig = *template + } + return tlsConfig +} + +func respBadGateway(resp http.ResponseWriter, msg string) { + log.Println(msg) + resp.WriteHeader(502) + resp.Write([]byte(msg)) +} + +//两个io口的连接 +func Transport(conn1, conn2 net.Conn) (err error) { + rChan := make(chan error, 1) + wChan := make(chan error, 1) + + go MyCopy(conn1, conn2, wChan) + go MyCopy(conn2, conn1, rChan) + + select { + case err = <-wChan: + case err = <-rChan: + } return } + +func MyCopy(src io.Reader, dst io.Writer, ch chan<- error) { + _, err := io.Copy(dst, src) + ch <- err +} + +func connectProxyServer(conn net.Conn, addr string) error { + + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Host: addr}, + Host: addr, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + req.Header.Set("Proxy-Connection", "keep-alive") + + if err := req.Write(conn); err != nil { + return err + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +}