package util

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"

	"gitea.ckfah.com/cjjy/gocommon/pkg/net/engines"
)

const (
	// defaultHttpRequestMethod is used when no method option is supplied.
	defaultHttpRequestMethod = http.MethodPost
	// defaultHttpRequestPath is both the route the handlers are mounted on and
	// the path the mock request targets.
	defaultHttpRequestPath = "/"
)

var (
	// middleWareHandlerFunc holds handlers registered via
	// AddMiddleWareHandlerFunc. NewTestHttpRequest copies them to the front of
	// each new request's handler chain.
	middleWareHandlerFunc []engines.HandlerFunc
	// handlerRequestHeaders holds hooks registered via AddHandlerRequestHeaders.
	// NewTestHttpRequest copies them into each new request.
	handlerRequestHeaders []func(ctx context.Context, header http.Header)
)

// AddMiddleWareHandlerFunc registers a global middleware handler. Every
// TestHttpRequest created afterwards runs it ahead of its own handlers;
// requests created earlier are unaffected. A nil fun is ignored.
//
// It appends to package-level state without any locking, so call it during
// test setup (for example from TestMain) rather than from parallel tests.
func AddMiddleWareHandlerFunc(fun engines.HandlerFunc) {
	if fun == nil {
		return
	}
	middleWareHandlerFunc = append(middleWareHandlerFunc, fun)
}

// AddHandlerRequestHeaders registers a global header hook. Every
// TestHttpRequest created afterwards calls it (before any per-request header
// hooks) to populate the outgoing request's header. A nil fun is ignored.
//
// Like AddMiddleWareHandlerFunc, it mutates package-level state without
// locking; call it during test setup.
func AddHandlerRequestHeaders(fun func(ctx context.Context, header http.Header)) {
	if fun == nil {
		return
	}
	handlerRequestHeaders = append(handlerRequestHeaders, fun)
}

// Option configures a TestHttpRequest. NewTestHttpRequest applies options in
// the order they are given; later options may override earlier ones.
type Option func(request *TestHttpRequest)

// TestHttpRequest describes one in-process request against a handler chain.
// Build it with NewTestHttpRequest and execute it with MockHttpRequest.
type TestHttpRequest struct {
	// Ctx is attached to the outgoing request and passed to header hooks.
	Ctx context.Context
	// Method is both the method the handlers are registered for and the
	// method of the outgoing request.
	Method string
	// Path is both the route the handlers are mounted on and the request path.
	Path string
	// RequestBody is the outgoing request body; nil sends no body.
	RequestBody io.Reader

	// handlerRequestHeaders are called in order to populate the request header.
	handlerRequestHeaders []func(ctx context.Context, header http.Header)
	// header is not read or written in this file.
	// NOTE(review): confirm whether other files in the package use it.
	header http.Header
	// handlerFunctions is the handler chain registered with the engine.
	handlerFunctions []engines.HandlerFunc
	// params are encoded into the request URL's query string (for every
	// method, including POST).
	params url.Values
}

var (
	// PostMethodOption sends the request (and registers the handlers) as POST.
	PostMethodOption Option = func(request *TestHttpRequest) {
		request.Method = http.MethodPost
	}

	// GetMethodOption sends the request (and registers the handlers) as GET.
	GetMethodOption Option = func(request *TestHttpRequest) {
		request.Method = http.MethodGet
	}

	// RequestHeaderOption adds a per-request header hook, called after the
	// global hooks. A nil handler yields a no-op Option instead of a hook that
	// would panic when the request is executed.
	RequestHeaderOption = func(handler func(ctx context.Context, header http.Header)) Option {
		if handler == nil {
			return func(*TestHttpRequest) {}
		}
		return func(request *TestHttpRequest) {
			request.handlerRequestHeaders = append(request.handlerRequestHeaders, handler)
		}
	}

	// AddFormParamOption sets key to value in the request parameters, replacing
	// any previous value for key. The parameters are sent in the URL query
	// string, not in the body.
	AddFormParamOption = func(key string, value string) Option {
		return func(request *TestHttpRequest) {
			// Guard against requests not built by NewTestHttpRequest, whose
			// params map would be nil and make Set panic.
			if request.params == nil {
				request.params = url.Values{}
			}
			request.params.Set(key, value)
		}
	}

	// JsonBodyOption uses jsonBody as the request body. It does not set a
	// Content-Type header; add one with RequestHeaderOption if the handler
	// requires it.
	JsonBodyOption = func(jsonBody string) Option {
		return func(request *TestHttpRequest) {
			request.RequestBody = strings.NewReader(jsonBody)
		}
	}

	// MultipartFormDataOption builds a multipart body from params and
	// fileTypeParams via NewMultiFormRequest (defined elsewhere in this
	// package). The header value NewMultiFormRequest returns is sent as
	// Content-Type.
	//
	// On error it returns a nil Option together with the error.
	// The body reader is created once, so the returned Option must be applied
	// to at most one request.
	MultipartFormDataOption = func(params map[string]string, fileTypeParams []FileForm) (Option, error) {
		reader, headerValue, err := NewMultiFormRequest(params, fileTypeParams)
		if err != nil {
			return nil, err
		}
		return func(request *TestHttpRequest) {
			setContentType := RequestHeaderOption(func(ctx context.Context, header http.Header) {
				header.Set("Content-Type", headerValue)
			})
			setContentType(request)
			request.RequestBody = reader
		}, nil
	}

	// HandlerFuncOption appends handler to the request's handler chain, after
	// any global middleware. A nil handler yields a no-op Option so it is
	// never registered with the engine.
	HandlerFuncOption = func(handler engines.HandlerFunc) Option {
		if handler == nil {
			return func(*TestHttpRequest) {}
		}
		return func(request *TestHttpRequest) {
			request.handlerFunctions = append(request.handlerFunctions, handler)
		}
	}

	// WithContextOption attaches ctx to the outgoing request and header hooks.
	// A nil ctx makes MockHttpRequest return an error from
	// http.NewRequestWithContext.
	WithContextOption = func(ctx context.Context) Option {
		return func(request *TestHttpRequest) {
			request.Ctx = ctx
		}
	}
)

// NewTestHttpRequestWithDefault sends a POST request with jsonBody as its body
// through handler. It is equivalent to NewTestPostHttpRequest.
func NewTestHttpRequestWithDefault(jsonBody string, handler ...engines.HandlerFunc) (*http.Response, error) {
	return NewTestPostHttpRequest(jsonBody, handler...)
}

// NewTestPostHttpRequest sends a POST request with jsonBody as its body through
// handler and returns the recorded response. No Content-Type header is set
// unless a global header hook adds one.
func NewTestPostHttpRequest(jsonBody string, handler ...engines.HandlerFunc) (*http.Response, error) {
	params := make([]Option, 0, len(handler)+1)
	params = append(params, JsonBodyOption(jsonBody))
	for _, elem := range handler {
		params = append(params, HandlerFuncOption(elem))
	}
	return NewTestHttpRequest(params...).MockHttpRequest()
}

// NewTestHttpRequestWithUploadFile sends a POST multipart/form-data request
// built from requestParams and fileTypeParams through handler. It returns the
// error from building the multipart body, if any.
func NewTestHttpRequestWithUploadFile(requestParams map[string]string, fileTypeParams []FileForm, handler ...engines.HandlerFunc) (*http.Response, error) {
	params := make([]Option, 0, len(handler)+1)
	option, err := MultipartFormDataOption(requestParams, fileTypeParams)
	if err != nil {
		return nil, err
	}
	params = append(params, option)
	for _, elem := range handler {
		params = append(params, HandlerFuncOption(elem))
	}
	return NewTestHttpRequest(params...).MockHttpRequest()
}

// NewTestHttpRequestWithDownLoadFile sends a GET request with requestParams in
// the query string through handler. It is equivalent to NewTestGetHttpRequest.
func NewTestHttpRequestWithDownLoadFile(requestParams map[string]string, handler ...engines.HandlerFunc) (*http.Response, error) {
	return NewTestGetHttpRequest(requestParams, handler...)
}

// NewTestGetHttpRequest sends a GET request with requestParams encoded in the
// query string through handler and returns the recorded response.
func NewTestGetHttpRequest(requestParams map[string]string, handler ...engines.HandlerFunc) (*http.Response, error) {
	params := make([]Option, 0, len(handler)+1+len(requestParams))
	for key, value := range requestParams {
		params = append(params, AddFormParamOption(key, value))
	}
	params = append(params, GetMethodOption)
	for _, elem := range handler {
		params = append(params, HandlerFuncOption(elem))
	}
	return NewTestHttpRequest(params...).MockHttpRequest()
}

// NewTestHttpRequest builds a TestHttpRequest with these defaults:
//   - method POST and path "/";
//   - context.Background();
//   - no params;
//   - a copy of the global middleware and header hooks.
//
// It then applies ops in order. Nil options are skipped.
func NewTestHttpRequest(ops ...Option) *TestHttpRequest {
	request := &TestHttpRequest{
		Ctx:    context.Background(),
		Method: defaultHttpRequestMethod,
		Path:   defaultHttpRequestPath,
		params: url.Values{},
		// Copy the globals so appends made by options never write into the
		// shared package-level backing arrays.
		handlerFunctions:      append([]engines.HandlerFunc{}, middleWareHandlerFunc...),
		handlerRequestHeaders: append([]func(ctx context.Context, header http.Header){}, handlerRequestHeaders...),
	}
	for _, op := range ops {
		if op == nil {
			// e.g. the nil Option MultipartFormDataOption returns with an error.
			continue
		}
		op(request)
	}
	return request
}

// MockHttpRequest executes the request in-process and returns the recorded
// response.
//
// It registers the handler chain on a fresh engine for Method and Path and
// mounts that engine at Path on a new ServeMux. It then builds the request,
// runs the header hooks in order, encodes params into the query string, and
// serves the request through an httptest.ResponseRecorder.
//
// The only error comes from http.NewRequestWithContext (for example a nil Ctx
// or an invalid Method/Path).
func (request *TestHttpRequest) MockHttpRequest() (*http.Response, error) {
	mux := http.NewServeMux()
	engine := engines.New()
	engine.Handle(request.Method, request.Path, request.handlerFunctions...)
	mux.Handle(request.Path, engine)

	r, err := http.NewRequestWithContext(request.Ctx, request.Method, request.Path, request.RequestBody)
	if err != nil {
		return nil, err
	}
	for _, setHeader := range request.handlerRequestHeaders {
		if setHeader == nil {
			continue
		}
		setHeader(request.Ctx, r.Header)
	}
	r.URL.RawQuery = request.params.Encode()

	w := httptest.NewRecorder()
	mux.ServeHTTP(w, r)
	return w.Result(), nil
}

// newRequestBody returns an empty in-memory buffer usable as a request body.
// It is not referenced in this file.
// NOTE(review): presumably used by NewMultiFormRequest elsewhere in the
// package — remove it if nothing calls it.
func newRequestBody() io.ReadWriter {
	return &bytes.Buffer{}
}