diff --git a/src/color/tcolor.go b/src/color/tcolor.go new file mode 100644 index 0000000..5d92142 --- /dev/null +++ b/src/color/tcolor.go @@ -0,0 +1,84 @@ +package color + +import ( + "fmt" + "runtime" +) + +const ( + TextBlack = iota + 30 + TextRed + TextGreen + TextYellow + TextBlue + TextMagenta + TextCyan + TextWhite +) + +func Black(str string) string { + return textColor(TextBlack, str) +} + +func Red(str string) string { + return textColor(TextRed, str) +} + +func Green(str string) string { + return textColor(TextGreen, str) +} + +func Yellow(str string) string { + return textColor(TextYellow, str) +} + +func Blue(str string) string { + return textColor(TextBlue, str) +} + +func Magenta(str string) string { + return textColor(TextMagenta, str) +} + +func Cyan(str string) string { + return textColor(TextCyan, str) +} + +func White(str string) string { + return textColor(TextWhite, str) +} + +func textColor(color int, str string) string { + if IsWindows() { + return str + } + + switch color { + case TextBlack: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextBlack, str) + case TextRed: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextRed, str) + case TextGreen: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextGreen, str) + case TextYellow: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextYellow, str) + case TextBlue: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextBlue, str) + case TextMagenta: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextMagenta, str) + case TextCyan: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextCyan, str) + case TextWhite: + return fmt.Sprintf("\x1b[0;%dm%s\x1b[0m", TextWhite, str) + default: + return str + } +} + +func IsWindows() bool { + if runtime.GOOS == "windows" { + return true + } else { + return false + } +} diff --git a/src/config/config.go b/src/config/config.go new file mode 100644 index 0000000..45794e5 --- /dev/null +++ b/src/config/config.go @@ -0,0 +1,56 @@ +package config + +import "crypto/tls" + +type Cfg struct { + Port *string + Raddr *string + Log *string + Monitor *bool + Tls *bool +} + +type TlsConfig struct { + PrivateKeyFile string + CertFile string + Organization string + CommonName string + ServerTLSConfig *tls.Config +} + +func NewTlsConfig(pk, cert, org, cn string) *TlsConfig { + return &TlsConfig{ + PrivateKeyFile: pk, + CertFile: cert, + Organization: org, + CommonName: cn, + ServerTLSConfig: &tls.Config{ + CipherSuites: []uint16{ + tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_FALLBACK_SCSV, + }, + PreferServerCipherSuites: true, + }, + } +} diff --git a/src/main.go b/src/main.go new file mode 100644 index 0000000..9d7d0bf --- /dev/null +++ b/src/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "config" + "flag" + "io" + "io/ioutil" + "mitm" + "mylog" + "net/http" + "os" + "sync" +) + +//func main() { +// +// println("Hello, Running Success!") +//} + +func main() { + var log io.WriteCloser + var err error + // cofig + conf := new(conf.Cfg) + conf.Port = flag.String("port", "8080", "端口") + conf.Raddr = flag.String("raddr", "", "远程地址") + conf.Log = flag.String("logFile", "", "本地日志路径") + conf.Monitor = flag.Bool("m", false, "监听模式(捕获内容)") + conf.Tls = flag.Bool("tls", false, "是否启用tls连接") + //pool := flag.String("pool", "", "代理池地址,格式 ip:端口 ") + + flag.Parse() + + // init log + if *conf.Log != "" { + log, err = os.Create(*conf.Log) + if err != nil { + mylog.Fatalln("创建日志文件失败 " + err.Error()) + } + } else { + log = os.Stderr + } + mylog.SetLog(log) + + // add current node to remote pool + //go regToPool(*pool,*conf.Port,*conf.Tls) + + // init tls config + tlsConfig := config.NewTlsConfig("gomitmproxy-ca-pk.pem", "gomitmproxy-ca-cert.pem", "", "") + // start mitm proxy + wg := new(sync.WaitGroup) + wg.Add(1) + mitm.Gomitmproxy(conf, tlsConfig, wg) + wg.Wait() +} + +// 注册自己到代理池中心 +func regToPool(pool string, port string, proto bool) { + mylog.Println("开始注册到代理池 ") + protoStr := "http" + if proto { + protoStr = "https" + } + pollURL := "http://" + pool + "/v2/add?port=" + port + "&proto=" + protoStr + mylog.Println("url: %s ", pollURL) + resp, err := http.Get(pollURL) + if err != nil { + mylog.Println("出现错误 ", err.Error()) + } + defer resp.Body.Close() + bodyBytes, err := ioutil.ReadAll(resp.Body) + mylog.Println("代理池返回结果:" + string(bodyBytes)) + +} diff --git a/src/main/main.go b/src/main/main.go deleted file mode 100644 index 13f451e..0000000 --- a/src/main/main.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -import ( - "config" - "flag" - "io" - "mitm" - "mylog" - "os" - "sync" -) - -func main() { - var log io.WriteCloser - var err error - // cofig - conf := new(config.Cfg) - conf.Port = flag.String("port", "8080", "Listen port") - conf.Raddr = flag.String("raddr", "", "Remote addr") - conf.Log = flag.String("logFile", "", "log file path") - conf.Monitor = flag.Bool("m", false, "monitor mode") - conf.Tls = flag.Bool("tls", false, "tls connect") - - flag.Parse() - - // init log - if *conf.Log != "" { - log, err = os.Create(*conf.Log) - if err != nil { - mylog.Fatalln("fail to create log file " + err.Error()) - } - } else { - log = os.Stderr - } - mylog.SetLog(log) - - // init tls config - tlsConfig := config.NewTlsConfig("gomitmproxy-ca-pk.pem", "gomitmproxy-ca-cert.pem", "", "") - // start mitm proxy - wg := new(sync.WaitGroup) - wg.Add(1) - mitm.Gomitmproxy(conf, tlsConfig, wg) - wg.Wait() -} diff --git a/src/main/main.go.bak b/src/main/main.go.bak new file mode 100644 index 0000000..9d7d0bf --- /dev/null +++ b/src/main/main.go.bak @@ -0,0 +1,74 @@ +package main + +import ( + "config" + "flag" + "io" + "io/ioutil" + "mitm" + "mylog" + "net/http" + "os" + "sync" +) + +//func main() { +// +// println("Hello, Running Success!") +//} + +func main() { + var log io.WriteCloser + var err error + // cofig + conf := new(conf.Cfg) + conf.Port = flag.String("port", "8080", "端口") + conf.Raddr = flag.String("raddr", "", "远程地址") + conf.Log = flag.String("logFile", "", "本地日志路径") + conf.Monitor = flag.Bool("m", false, "监听模式(捕获内容)") + conf.Tls = flag.Bool("tls", false, "是否启用tls连接") + //pool := flag.String("pool", "", "代理池地址,格式 ip:端口 ") + + flag.Parse() + + // init log + if *conf.Log != "" { + log, err = os.Create(*conf.Log) + if err != nil { + mylog.Fatalln("创建日志文件失败 " + err.Error()) + } + } else { + log = os.Stderr + } + mylog.SetLog(log) + + // add current node to remote pool + //go regToPool(*pool,*conf.Port,*conf.Tls) + + // init tls config + tlsConfig := config.NewTlsConfig("gomitmproxy-ca-pk.pem", "gomitmproxy-ca-cert.pem", "", "") + // start mitm proxy + wg := new(sync.WaitGroup) + wg.Add(1) + mitm.Gomitmproxy(conf, tlsConfig, wg) + wg.Wait() +} + +// 注册自己到代理池中心 +func regToPool(pool string, port string, proto bool) { + mylog.Println("开始注册到代理池 ") + protoStr := "http" + if proto { + protoStr = "https" + } + pollURL := "http://" + pool + "/v2/add?port=" + port + "&proto=" + protoStr + mylog.Println("url: %s ", pollURL) + resp, err := http.Get(pollURL) + if err != nil { + mylog.Println("出现错误 ", err.Error()) + } + defer resp.Body.Close() + bodyBytes, err := ioutil.ReadAll(resp.Body) + mylog.Println("代理池返回结果:" + string(bodyBytes)) + +} diff --git a/src/mitm/cache.go b/src/mitm/cache.go new file mode 100644 index 0000000..d2c8022 --- /dev/null +++ b/src/mitm/cache.go @@ -0,0 +1,48 @@ +// package cache implements a really primitive cache that associates expiring +// values with string keys. This cache never clears itself out. +package mitm + +import ( + "sync" + "time" +) + +// Cache is a cache for binary data +type Cache struct { + entries map[string]*entry + mutex sync.RWMutex +} + +// entry is an entry in a Cache +type entry struct { + data interface{} + expiration time.Time +} + +// NewCache creates a new Cache +func NewCache() *Cache { + return &Cache{entries: make(map[string]*entry)} +} + +// Get returns the currently cached value for the given key, as long as it +// hasn't expired. If the key was never set, or has expired, found will be +// false. +func (cache *Cache) Get(key string) (val interface{}, found bool) { + cache.mutex.RLock() + defer cache.mutex.RUnlock() + entry := cache.entries[key] + if entry == nil { + return nil, false + } else if entry.expiration.Before(time.Now()) { + return nil, false + } else { + return entry.data, true + } +} + +// Set sets a value in the cache with an expiration of now + ttl. +func (cache *Cache) Set(key string, data interface{}, ttl time.Duration) { + cache.mutex.Lock() + defer cache.mutex.Unlock() + cache.entries[key] = &entry{data, time.Now().Add(ttl)} +} diff --git a/src/mitm/dump.go b/src/mitm/dump.go new file mode 100644 index 0000000..6887eff --- /dev/null +++ b/src/mitm/dump.go @@ -0,0 +1,100 @@ +package mitm + +import ( + "bufio" + "bytes" + "color" + "compress/flate" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "math" + "mylog" + "net/http" + "strconv" +) + +func httpDump(reqDump []byte, resp *http.Response) { + defer resp.Body.Close() + var respStatusStr string + respStatus := resp.StatusCode + respStatusHeader := int(math.Floor(float64(respStatus / 100))) + switch respStatusHeader { + case 2: + respStatusStr = color.Green("<--" + strconv.Itoa(respStatus)) + case 3: + respStatusStr = color.Yellow("<--" + strconv.Itoa(respStatus)) + case 4: + respStatusStr = color.Magenta("<--" + strconv.Itoa(respStatus)) + case 5: + respStatusStr = color.Red("<--" + strconv.Itoa(respStatus)) + } + + fmt.Println(color.Green("Request:"), respStatusStr) + req, _ := ParseReq(reqDump) + fmt.Printf("%s %s %s\n", color.Blue(req.Method), req.Host+req.RequestURI, respStatusStr) + fmt.Printf("%s %s\n", color.Blue("RemoteAddr:"), req.RemoteAddr) + for headerName, headerContext := range req.Header { + fmt.Printf("%s: %s\n", color.Blue(headerName), headerContext) + } + + if req.Method == "POST" { + fmt.Println(color.Green("POST Param:")) + err := req.ParseForm() + if err != nil { + mylog.Println("parseForm error:", err) + } else { + for k, v := range req.Form { + fmt.Printf("\t%s: %s\n", color.Blue(k), v) + } + } + } + fmt.Println(color.Green("Response:")) + for headerName, headerContext := range resp.Header { + fmt.Printf("%s: %s\n", color.Blue(headerName), headerContext) + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + mylog.Println("func httpDump read resp body err:", err) + } else { + acceptEncode := resp.Header["Content-Encoding"] + var respBodyBin bytes.Buffer + w := bufio.NewWriter(&respBodyBin) + w.Write(respBody) + w.Flush() + for _, compress := range acceptEncode { + switch compress { + case "gzip": + r, err := gzip.NewReader(&respBodyBin) + if err != nil { + mylog.Println("gzip reader err:", err) + } else { + defer r.Close() + respBody, _ = ioutil.ReadAll(r) + } + break + case "deflate": + r := flate.NewReader(&respBodyBin) + defer r.Close() + respBody, _ = ioutil.ReadAll(r) + break + } + } + fmt.Printf("%s\n", string(respBody)) + } + + fmt.Printf("%s%s%s\n", color.Black("####################"), color.Cyan("END"), color.Black("####################")) +} + +func ParseReq(b []byte) (*http.Request, error) { + // func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) } + fmt.Println(string(b)) + fmt.Println("-----------------------") + var buf io.ReadWriter + buf = new(bytes.Buffer) + buf.Write(b) + bufr := bufio.NewReader(buf) + return http.ReadRequest(bufr) +} diff --git a/src/mitm/gen_key.go b/src/mitm/gen_key.go new file mode 100644 index 0000000..d83e076 --- /dev/null +++ b/src/mitm/gen_key.go @@ -0,0 +1,310 @@ +// Package keyman provides convenience APIs around Go's built-in crypto APIs. +package mitm + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "mylog" + "net" + "os" + "time" +) + +const ( + PEM_HEADER_PRIVATE_KEY = "RSA PRIVATE KEY" + PEM_HEADER_PUBLIC_KEY = "RSA PRIVATE KEY" + PEM_HEADER_CERTIFICATE = "CERTIFICATE" +) + +var ( + tenYearsFromToday = time.Now().AddDate(10, 0, 0) +) + +// PrivateKey is a convenience wrapper for rsa.PrivateKey +type PrivateKey struct { + rsaKey *rsa.PrivateKey +} + +// Certificate is a convenience wrapper for x509.Certificate +type Certificate struct { + cert *x509.Certificate + derBytes []byte +} + +/******************************************************************************* + * Private Key Functions + ******************************************************************************/ + +// GeneratePK generates a PrivateKey with a specified size in bits. +func GeneratePK(bits int) (key *PrivateKey, err error) { + var rsaKey *rsa.PrivateKey + rsaKey, err = rsa.GenerateKey(rand.Reader, bits) + if err == nil { + key = &PrivateKey{rsaKey: rsaKey} + } + return +} + +// LoadPKFromFile loads a PEM-encoded PrivateKey from a file +func LoadPKFromFile(filename string) (key *PrivateKey, err error) { + privateKeyData, err := ioutil.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + return nil, err + } + return nil, fmt.Errorf("Unable to read private key file from file %s: %s", filename, err) + } + block, _ := pem.Decode(privateKeyData) + if block == nil { + return nil, fmt.Errorf("Unable to decode PEM encoded private key data: %s", err) + } + rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("Unable to decode X509 private key data: %s", err) + } + return &PrivateKey{rsaKey: rsaKey}, nil +} + +// PEMEncoded encodes the PrivateKey in PEM +func (key *PrivateKey) PEMEncoded() (pemBytes []byte) { + return pem.EncodeToMemory(key.pemBlock()) +} + +// WriteToFile writes the PEM-encoded PrivateKey to the given file +func (key *PrivateKey) WriteToFile(filename string) (err error) { + keyOut, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("Failed to open %s for writing: %s", filename, err) + } + if err := pem.Encode(keyOut, key.pemBlock()); err != nil { + return fmt.Errorf("Unable to PEM encode private key: %s", err) + } + if err := keyOut.Close(); err != nil { + mylog.Printf("Unable to close file: %v", err) + } + return +} + +func (key *PrivateKey) pemBlock() *pem.Block { + return &pem.Block{Type: PEM_HEADER_PRIVATE_KEY, Bytes: x509.MarshalPKCS1PrivateKey(key.rsaKey)} +} + +/******************************************************************************* + * Certificate Functions + ******************************************************************************/ + +/* +Certificate() generates a certificate for the Public Key of the given PrivateKey +based on the given template and signed by the given issuer. If issuer is nil, +the generated certificate is self-signed. +*/ +func (key *PrivateKey) Certificate(template *x509.Certificate, issuer *Certificate) (*Certificate, error) { + return key.CertificateForKey(template, issuer, &key.rsaKey.PublicKey) +} + +/* +CertificateForKey() generates a certificate for the given Public Key based on +the given template and signed by the given issuer. If issuer is nil, the +generated certificate is self-signed. +*/ +func (key *PrivateKey) CertificateForKey(template *x509.Certificate, issuer *Certificate, publicKey interface{}) (*Certificate, error) { + var issuerCert *x509.Certificate + if issuer == nil { + // Note - for self-signed certificates, we include the host's external IP address + issuerCert = template + } else { + issuerCert = issuer.cert + } + derBytes, err := x509.CreateCertificate( + rand.Reader, // secure entropy + template, // the template for the new cert + issuerCert, // cert that's signing this cert + publicKey, // public key + key.rsaKey, // private key + ) + if err != nil { + return nil, err + } + return bytesToCert(derBytes) +} + +// TLSCertificateFor generates a certificate useful for TLS use based on the +// given parameters. These certs are usable for key encipherment and digital +// signatures. +// +// organization: the org name for the cert. +// name: used as the common name for the cert. If name is an IP +// address, it is also added as an IP SAN. +// validUntil: time at which certificate expires +// isCA: whether or not this cert is a CA +// issuer: the certificate which is issuing the new cert. If nil, the +// new cert will be a self-signed CA certificate. +// +func (key *PrivateKey) TLSCertificateFor( + organization string, + name string, + validUntil time.Time, + isCA bool, + issuer *Certificate) (cert *Certificate, err error) { + + template := &x509.Certificate{ + SerialNumber: new(big.Int).SetInt64(int64(time.Now().UnixNano())), + Subject: pkix.Name{ + Organization: []string{organization}, + CommonName: name, + }, + NotBefore: time.Now().AddDate(0, -1, 0), + NotAfter: validUntil, + + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + } + + // If name is an ip address, add it as an IP SAN + ip := net.ParseIP(name) + if ip != nil { + template.IPAddresses = []net.IP{ip} + } + + isSelfSigned := issuer == nil + if isSelfSigned { + template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} + } + + // If it's a CA, add certificate signing + if isCA { + template.KeyUsage = template.KeyUsage | x509.KeyUsageCertSign + template.IsCA = true + } + + cert, err = key.Certificate(template, issuer) + return +} + +// LoadCertificateFromFile loads a Certificate from a PEM-encoded file +func LoadCertificateFromFile(filename string) (*Certificate, error) { + certificateData, err := ioutil.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + return nil, err + } + return nil, fmt.Errorf("Unable to read certificate file from disk: %s", err) + } + return LoadCertificateFromPEMBytes(certificateData) +} + +// LoadCertificateFromPEMBytes loads a Certificate from a byte array in PEM +// format +func LoadCertificateFromPEMBytes(pemBytes []byte) (*Certificate, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, fmt.Errorf("Unable to decode PEM encoded certificate") + } + return bytesToCert(block.Bytes) +} + +// LoadCertificateFromX509 loads a Certificate from an x509.Certificate +func LoadCertificateFromX509(cert *x509.Certificate) (*Certificate, error) { + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Headers: nil, + Bytes: cert.Raw, + }) + return LoadCertificateFromPEMBytes(pemBytes) +} + +// X509 returns the x509 certificate underlying this Certificate +func (cert *Certificate) X509() *x509.Certificate { + return cert.cert +} + +// PEMEncoded encodes the Certificate in PEM +func (cert *Certificate) PEMEncoded() (pemBytes []byte) { + return pem.EncodeToMemory(cert.pemBlock()) +} + +// WriteToFile writes the PEM-encoded Certificate to a file. +func (cert *Certificate) WriteToFile(filename string) (err error) { + certOut, err := os.Create(filename) + if err != nil { + return fmt.Errorf("Failed to open %s for writing: %s", filename, err) + } + defer func() { + if err := certOut.Close(); err != nil { + mylog.Printf("Unable to close file: %v", err) + } + }() + return pem.Encode(certOut, cert.pemBlock()) +} + +func (cert *Certificate) WriteToTempFile() (name string, err error) { + // Create a temp file containing the certificate + tempFile, err := ioutil.TempFile("", "tempCert") + if err != nil { + return "", fmt.Errorf("Unable to create temp file: %s", err) + } + name = tempFile.Name() + err = cert.WriteToFile(name) + if err != nil { + return "", fmt.Errorf("Unable to save certificate to temp file: %s", err) + } + return +} + +// WriteToDERFile writes the DER-encoded Certificate to a file. +func (cert *Certificate) WriteToDERFile(filename string) (err error) { + certOut, err := os.Create(filename) + if err != nil { + return fmt.Errorf("Failed to open %s for writing: %s", filename, err) + } + defer func() { + if err := certOut.Close(); err != nil { + mylog.Printf("Unable to close file: %v", err) + } + }() + _, err = certOut.Write(cert.derBytes) + return err +} + +// PoolContainingCert creates a pool containing this cert. +func (cert *Certificate) PoolContainingCert() *x509.CertPool { + pool := x509.NewCertPool() + pool.AddCert(cert.cert) + return pool +} + +// PoolContainingCerts constructs a CertPool containing all of the given certs +// (PEM encoded). +func PoolContainingCerts(certs ...string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + for _, cert := range certs { + c, err := LoadCertificateFromPEMBytes([]byte(cert)) + if err != nil { + return nil, err + } + pool.AddCert(c.cert) + } + return pool, nil +} + +func (cert *Certificate) ExpiresBefore(time time.Time) bool { + return cert.cert.NotAfter.Before(time) +} + +func bytesToCert(derBytes []byte) (*Certificate, error) { + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, err + } + return &Certificate{cert, derBytes}, nil +} + +func (cert *Certificate) pemBlock() *pem.Block { + return &pem.Block{Type: PEM_HEADER_CERTIFICATE, Bytes: cert.derBytes} +} diff --git a/src/mitm/gomitmproxy.go b/src/mitm/gomitmproxy.go new file mode 100644 index 0000000..e158c18 --- /dev/null +++ b/src/mitm/gomitmproxy.go @@ -0,0 +1,45 @@ +// This example shows a proxy server that uses go-mitm to man-in-the-middle +// HTTPS connections opened with CONNECT requests + +package mitm + +import ( + "config" + "mylog" + "net/http" + "sync" + "time" +) + +func Gomitmproxy(conf *config.Cfg, tlsConfig *config.TlsConfig, wg *sync.WaitGroup) { + handler, err := InitConfig(conf, tlsConfig) + if err != nil { + mylog.Fatalf("InitConfig error: %s", err) + } + + server := &http.Server{ + Addr: ":" + *conf.Port, + Handler: handler, + ReadTimeout: 1 * time.Hour, + WriteTimeout: 1 * time.Hour, + } + + go func() { + mylog.Printf("Gomitmproxy Listening On: %s", *conf.Port) + if *conf.Tls { + mylog.Println("Listen And Serve HTTP TLS") + err = server.ListenAndServeTLS("gomitmproxy-ca-cert.pem", "gomitmproxy-ca-pk.pem") + } else { + mylog.Println("Listen And Serve HTTP") + err = server.ListenAndServe() + } + if err != nil { + mylog.Fatalf("Unable To Start HTTP proxy: %s", err) + } + + wg.Done() + mylog.Printf("Gomitmproxy Stop!!!!") + }() + + return +} diff --git a/src/mitm/listener.go b/src/mitm/listener.go new file mode 100644 index 0000000..c1d14a5 --- /dev/null +++ b/src/mitm/listener.go @@ -0,0 +1,28 @@ +package mitm + +import ( + "io" + "net" +) + +type mitmListener struct { + conn net.Conn +} + +func (listener *mitmListener) Accept() (net.Conn, error) { + if listener.conn != nil { + conn := listener.conn + listener.conn = nil + return conn, nil + } else { + return nil, io.EOF + } +} + +func (listener *mitmListener) Close() error { + return nil +} + +func (listener *mitmListener) Addr() net.Addr { + return nil +} diff --git a/src/mitm/mitm.go b/src/mitm/mitm.go new file mode 100644 index 0000000..aef8d87 --- /dev/null +++ b/src/mitm/mitm.go @@ -0,0 +1,393 @@ +package mitm + +import ( + "bufio" + "config" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "mylog" + "net" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "strings" + "sync" + "time" +) + +const ( + Version = "1.1" + ONE_DAY = 24 * time.Hour + TWO_WEEKS = ONE_DAY * 14 + ONE_MONTH = 1 + ONE_YEAR = 1 +) + +type HandlerWrapper struct { + MyConfig *config.Cfg + tlsConfig *config.TlsConfig + wrapped http.Handler + pk *PrivateKey + pkPem []byte + issuingCert *Certificate + issuingCertPem []byte + serverTLSConfig *tls.Config + dynamicCerts *Cache + certMutex sync.Mutex + https bool +} + +func (hw *HandlerWrapper) GenerateCertForClient() (err error) { + if hw.tlsConfig.Organization == "" { + hw.tlsConfig.Organization = "gomitmproxy" + Version + } + if hw.tlsConfig.CommonName == "" { + hw.tlsConfig.CommonName = "gomitmproxy" + } + if hw.pk, err = LoadPKFromFile(hw.tlsConfig.PrivateKeyFile); err != nil { + hw.pk, err = GeneratePK(2048) + if err != nil { + return fmt.Errorf("Unable to generate private key: %s", err) + } + hw.pk.WriteToFile(hw.tlsConfig.PrivateKeyFile) + } + hw.pkPem = hw.pk.PEMEncoded() + hw.issuingCert, err = LoadCertificateFromFile(hw.tlsConfig.CertFile) + if err != nil || hw.issuingCert.ExpiresBefore(time.Now().AddDate(0, ONE_MONTH, 0)) { + hw.issuingCert, err = hw.pk.TLSCertificateFor( + hw.tlsConfig.Organization, + hw.tlsConfig.CommonName, + time.Now().AddDate(ONE_YEAR, 0, 0), + true, + nil) + if err != nil { + return fmt.Errorf("Unable to generate self-signed issuing certificate: %s", err) + } + hw.issuingCert.WriteToFile(hw.tlsConfig.CertFile) + } + hw.issuingCertPem = hw.issuingCert.PEMEncoded() + return +} + +func (hw *HandlerWrapper) FakeCertForName(name string) (cert *tls.Certificate, err error) { + kpCandidateIf, found := hw.dynamicCerts.Get(name) + if found { + return kpCandidateIf.(*tls.Certificate), nil + } + + hw.certMutex.Lock() + defer hw.certMutex.Unlock() + kpCandidateIf, found = hw.dynamicCerts.Get(name) + if found { + return kpCandidateIf.(*tls.Certificate), nil + } + + //create certificate + certTTL := TWO_WEEKS + generatedCert, err := hw.pk.TLSCertificateFor( + hw.tlsConfig.Organization, + name, + time.Now().Add(certTTL), + false, + hw.issuingCert) + if err != nil { + return nil, fmt.Errorf("Unable to issue certificate: %s", err) + } + keyPair, err := tls.X509KeyPair(generatedCert.PEMEncoded(), hw.pkPem) + if err != nil { + return nil, fmt.Errorf("Unable to parse keypair for tls: %s", err) + } + + cacheTTL := certTTL - ONE_DAY + hw.dynamicCerts.Set(name, &keyPair, cacheTTL) + return &keyPair, nil +} + +func (hw *HandlerWrapper) DumpHTTPAndHTTPs(resp http.ResponseWriter, req *http.Request) { + mylog.Println("DumpHTTPAndHTTPs") + req.Header.Del("Proxy-Connection") + req.Header.Set("Connection", "Keep-Alive") + var reqDump []byte + var err error + ch := make(chan bool) + // handle connection + go func() { + reqDump, err = httputil.DumpRequestOut(req, true) + ch <- true + }() + if err != nil { + mylog.Println("DumpRequest error ", err) + } + connIn, _, err := resp.(http.Hijacker).Hijack() + if err != nil { + mylog.Println("hijack error:", err) + } + defer connIn.Close() + + var respOut *http.Response + host := req.Host + + matched, _ := regexp.MatchString(":[0-9]+$", host) + + if !hw.https { + if !matched { + host += ":80" + } + + connOut, err := net.DialTimeout("tcp", host, time.Second*30) + if err != nil { + mylog.Println("dial to", host, "error:", err) + return + } + + if err = req.Write(connOut); err != nil { + mylog.Println("send to server error", err) + return + } + + respOut, err = http.ReadResponse(bufio.NewReader(connOut), req) + if err != nil && err != io.EOF { + mylog.Println("read response error:", err) + } + + } else { + if !matched { + host += ":443" + } + + connOut, err := tls.Dial("tcp", host, hw.tlsConfig.ServerTLSConfig) + if err != nil { + + } + if err = req.Write(connOut); err != nil { + mylog.Println("tls dial to", host, "error:", err) + return + } + if err = req.Write(connOut); err != nil { + mylog.Println("send to server error", err) + return + } + + respOut, err = http.ReadResponse(bufio.NewReader(connOut), req) + if err != nil && err != io.EOF { + mylog.Println("read response error:", err) + } + + } + + if respOut == nil { + log.Println("respOut is nil") + return + } + + respDump, err := httputil.DumpResponse(respOut, true) + if err != nil { + mylog.Println("respDump error:", err) + } + + _, err = connIn.Write(respDump) + if err != nil { + mylog.Println("connIn write error:", err) + } + + if *hw.MyConfig.Monitor { + <-ch + go httpDump(reqDump, respOut) + } else { + <-ch + } +} + +func (hw *HandlerWrapper) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + raddr := *hw.MyConfig.Raddr + if len(raddr) != 0 { + hw.Forward(resp, req, raddr) + } else { + if req.Method == "CONNECT" { + hw.https = true + hw.InterceptHTTPs(resp, req) + } else { + hw.https = false + hw.DumpHTTPAndHTTPs(resp, req) + } + } +} + +func (hw *HandlerWrapper) InterceptHTTPs(resp http.ResponseWriter, req *http.Request) { + mylog.Println("InterceptHTTPs") + addr := req.Host + host := strings.Split(addr, ":")[0] + + cert, err := hw.FakeCertForName(host) + if err != nil { + msg := fmt.Sprintf("Could not get mitm cert for name: %s\nerror: %s", host, err) + respBadGateway(resp, msg) + return + } + + // handle connection + connIn, _, err := resp.(http.Hijacker).Hijack() + if err != nil { + msg := fmt.Sprintf("Unable to access underlying connection from client: %s", err) + respBadGateway(resp, msg) + return + } + tlsConfig := copyTlsConfig(hw.tlsConfig.ServerTLSConfig) + tlsConfig.Certificates = []tls.Certificate{*cert} + tlsConnIn := tls.Server(connIn, tlsConfig) + listener := &mitmListener{tlsConnIn} + handler := http.HandlerFunc(func(resp2 http.ResponseWriter, req2 *http.Request) { + req2.URL.Scheme = "https" + req2.URL.Host = req2.Host + hw.DumpHTTPAndHTTPs(resp2, req2) + + }) + + go func() { + err = http.Serve(listener, handler) + if err != nil && err != io.EOF { + mylog.Printf("Error serving mitm'ed connection: %s", err) + } + }() + + connIn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) +} + +func (hw *HandlerWrapper) Forward(resp http.ResponseWriter, req *http.Request, raddr string) { + connIn, _, err := resp.(http.Hijacker).Hijack() + connOut, err := net.Dial("tcp", raddr) + if err != nil { + mylog.Println("dial tcp error", err) + } + + err = connectProxyServer(connOut, raddr) + if err != nil { + mylog.Println("connectProxyServer error:", err) + } + + if req.Method == "CONNECT" { + b := []byte("HTTP/1.1 200 Connection Established\r\n" + + "Proxy-Agent: gomitmproxy/" + Version + "\r\n" + + "Content-Length: 0" + "\r\n\r\n") + _, err := connIn.Write(b) + if err != nil { + mylog.Println("Write Connect err:", err) + return + } + } else { + req.Header.Del("Proxy-Connection") + req.Header.Set("Connection", "Keep-Alive") + if err = req.Write(connOut); err != nil { + mylog.Println("send to server err", err) + return + } + } + err = Transport(connIn, connOut) + if err != nil { + mylog.Println("trans error ", err) + } +} + +func InitConfig(conf *config.Cfg, tlsConfig *config.TlsConfig) (*HandlerWrapper, error) { + hw := &HandlerWrapper{ + MyConfig: conf, + tlsConfig: tlsConfig, + dynamicCerts: NewCache(), + } + err := hw.GenerateCertForClient() + if err != nil { + return nil, err + } + return hw, nil +} + +func copyTlsConfig(template *tls.Config) *tls.Config { + tlsConfig := &tls.Config{} + if template != nil { + *tlsConfig = *template + } + return tlsConfig +} + +func copyHTTPRequest(template *http.Request) *http.Request { + req := &http.Request{} + if template != nil { + *req = *template + } + return req +} + +func respBadGateway(resp http.ResponseWriter, msg string) { + log.Println(msg) + resp.WriteHeader(502) + resp.Write([]byte(msg)) +} + +//两个io口的连接 +func Transport(conn1, conn2 net.Conn) (err error) { + rChan := make(chan error, 1) + wChan := make(chan error, 1) + + go MyCopy(conn1, conn2, wChan) + go MyCopy(conn2, conn1, rChan) + + select { + case err = <-wChan: + case err = <-rChan: + } + + return +} + +func MyCopy(src io.Reader, dst io.Writer, ch chan<- error) { + _, err := io.Copy(dst, src) + ch <- err +} + +func connectProxyServer(conn net.Conn, addr string) error { + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Host: addr}, + Host: addr, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + req.Header.Set("Proxy-Connection", "keep-alive") + + if err := req.Write(conn); err != nil { + return err + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +/*func ReadNotDrain(r *http.Request) (content []byte, err error) { + content, err = ioutil.ReadAll(r.Body) + r.Body = io.ReadCloser(bytes.NewBuffer(content)) + return +} + +func ParsePostValues(req *http.Request) (url.Values, error) { + c, err := ReadNotDrain(req) + if err != nil { + return nil, err + } + values, err := url.ParseQuery(string(c)) + if err != nil { + return nil, err + } + return values, nil +} +*/ diff --git a/src/mylog/my_log.go b/src/mylog/my_log.go new file mode 100644 index 0000000..91523e4 --- /dev/null +++ b/src/mylog/my_log.go @@ -0,0 +1,34 @@ +package mylog + +import "log" +import "io" + +var logger *log.Logger + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) +} + +func SetLog(l io.WriteCloser) { + logger = log.New(l, "[gomitmproxy]", log.LstdFlags) +} + +func Fatalf(format string, v ...interface{}) { + logger.Fatalf(format, v) +} + +func Fatalln(v ...interface{}) { + logger.Fatalln(v) +} + +func Printf(format string, v ...interface{}) { + logger.Printf(format, v) +} + +func Println(v ...interface{}) { + logger.Println(v) +} + +func Panicln(v ...interface{}) { + logger.Panicln(v) +} diff --git a/src/mylog/my_log_test.go b/src/mylog/my_log_test.go new file mode 100644 index 0000000..e3698d1 --- /dev/null +++ b/src/mylog/my_log_test.go @@ -0,0 +1,17 @@ +package mylog + +import ( + "log" + "os" + "testing" +) + +func TestMyLog(t *testing.T) { + logFile, err := os.Create("test.log") + if err != nil { + log.Fatalln("fail to create log file!") + } + logger := log.New(logFile, "[gomitmproxy]", log.LstdFlags|log.Llongfile) + SetLog(logger) + Println("log test") +} diff --git a/src/mylog/test.log b/src/mylog/test.log new file mode 100644 index 0000000..2a76a76 --- /dev/null +++ b/src/mylog/test.log @@ -0,0 +1 @@ +[gomitmproxy]2017/04/03 00:00:25 /Users/bao/program/go/gowork/gomitmproxy/src/vendor/mylog/my_log.go:20: [log test]