Skip to content

Commit 8de87be

Browse files
chore (refactoring): cleanup server middleware
1 parent 7139d00 commit 8de87be

File tree

6 files changed

+176
-166
lines changed

6 files changed

+176
-166
lines changed

server/ctrl/webdav.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package ctrl
22

33
import (
4-
. "github.com/mickael-kerjean/filestash/server/common"
5-
"github.com/mickael-kerjean/filestash/server/model"
6-
"github.com/mickael-kerjean/net/webdav"
74
"net/http"
85
"path/filepath"
96
"strings"
7+
8+
. "github.com/mickael-kerjean/filestash/server/common"
9+
"github.com/mickael-kerjean/filestash/server/middleware"
10+
"github.com/mickael-kerjean/filestash/server/model"
11+
"github.com/mickael-kerjean/net/webdav"
1012
)
1113

1214
func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
@@ -53,8 +55,8 @@ func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
5355
* an imbecile and considering we can't even see the source code they are running, the best approach we
5456
* could go on is: "crap in, crap out" where useless request coming in are identified and answer appropriatly
5557
*/
56-
func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
57-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
58+
func WebdavBlacklist(fn middleware.HandlerFunc) middleware.HandlerFunc {
59+
return middleware.HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
5860
base := filepath.Base(req.URL.String())
5961

6062
if req.Method == "PUT" || req.Method == "MKCOL" {
@@ -125,5 +127,5 @@ func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx
125127
}
126128
}
127129
fn(ctx, res, req)
128-
}
130+
})
129131
}

server/middleware/context.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"strings"
1010
)
1111

12-
func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
12+
func BodyParser(fn HandlerFunc) HandlerFunc {
1313
extractBody := func(req *http.Request) (map[string]interface{}, error) {
1414
body := map[string]interface{}{}
1515
byt, err := ioutil.ReadAll(req.Body)
@@ -25,14 +25,14 @@ func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
2525
return body, nil
2626
}
2727

28-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
28+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
2929
var err error
3030
if ctx.Body, err = extractBody(req); err != nil {
3131
SendErrorResult(res, ErrNotValid)
3232
return
3333
}
3434
fn(ctx, res, req)
35-
}
35+
})
3636
}
3737

3838
func GenerateRequestID(prefix string) string {

server/middleware/http.go

+21-21
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
"strings"
1111
)
1212

13-
func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
14-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
13+
func ApiHeaders(fn HandlerFunc) HandlerFunc {
14+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
1515
header := res.Header()
1616
header.Set("Content-Type", "application/json")
1717
header.Set("Cache-Control", "no-cache")
@@ -20,20 +20,20 @@ func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
2020
header.Set("X-Request-ID", GenerateRequestID("API"))
2121
}
2222
fn(ctx, res, req)
23-
}
23+
})
2424
}
2525

26-
func StaticHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
27-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
26+
func StaticHeaders(fn HandlerFunc) HandlerFunc {
27+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
2828
header := res.Header()
2929
header.Set("Content-Type", GetMimeType(filepath.Ext(req.URL.Path)))
3030
header.Set("Cache-Control", "max-age=2592000")
3131
fn(ctx, res, req)
32-
}
32+
})
3333
}
3434

35-
func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
36-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
35+
func IndexHeaders(fn HandlerFunc) HandlerFunc {
36+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
3737
header := res.Header()
3838
header.Set("Content-Type", "text/html")
3939
header.Set("Cache-Control", "no-cache")
@@ -65,23 +65,23 @@ func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
6565
}
6666
// header.Set("Content-Security-Policy", cspHeader)
6767
fn(ctx, res, req)
68-
}
68+
})
6969
}
7070

71-
func SecureHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
72-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
71+
func SecureHeaders(fn HandlerFunc) HandlerFunc {
72+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
7373
header := res.Header()
7474
if Config.Get("general.force_ssl").Bool() {
7575
header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
7676
}
7777
header.Set("X-Content-Type-Options", "nosniff")
7878
header.Set("X-XSS-Protection", "1; mode=block")
7979
fn(ctx, res, req)
80-
}
80+
})
8181
}
8282

83-
func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
84-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
83+
func SecureOrigin(fn HandlerFunc) HandlerFunc {
84+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
8585
if host := Config.Get("general.host").String(); host != "" {
8686
host = strings.TrimPrefix(host, "http://")
8787
host = strings.TrimPrefix(host, "https://")
@@ -105,11 +105,11 @@ func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
105105

106106
Log.Warning("Intrusion detection: %s - %s", RetrievePublicIp(req), req.URL.String())
107107
SendErrorResult(res, ErrNotAllowed)
108-
}
108+
})
109109
}
110110

111-
func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
112-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
111+
func WithPublicAPI(fn HandlerFunc) HandlerFunc {
112+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
113113
apiKey := req.URL.Query().Get("key")
114114
if apiKey == "" {
115115
fn(ctx, res, req)
@@ -132,13 +132,13 @@ func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *
132132
return
133133
}
134134
fn(ctx, res, req)
135-
}
135+
})
136136
}
137137

138138
var limiter = rate.NewLimiter(10, 1000)
139139

140-
func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
141-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
140+
func RateLimiter(fn HandlerFunc) HandlerFunc {
141+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
142142
if limiter.Allow() == false {
143143
Log.Warning("middleware::http::ratelimit too many requests")
144144
SendErrorResult(
@@ -148,7 +148,7 @@ func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *Ap
148148
return
149149
}
150150
fn(ctx, res, req)
151-
}
151+
})
152152
}
153153

154154
func EnableCors(req *http.Request, res http.ResponseWriter, host string) error {

server/middleware/index.go

+4-118
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
package middleware
22

33
import (
4-
"bytes"
5-
"encoding/json"
64
. "github.com/mickael-kerjean/filestash/server/common"
75
"net/http"
8-
"sync"
96
"time"
107
)
118

12-
var telemetry = Telemetry{Data: make([]LogEntry, 0)}
9+
type HandlerFunc func(*App, http.ResponseWriter, *http.Request)
10+
type Middleware func(HandlerFunc) HandlerFunc
1311

1412
func init() {
1513
Hooks.Register.Onload(func() {
@@ -22,10 +20,7 @@ func init() {
2220
})
2321
}
2422

25-
type Middleware func(func(*App, http.ResponseWriter, *http.Request)) func(*App, http.ResponseWriter, *http.Request)
26-
27-
func NewMiddlewareChain(fn func(*App, http.ResponseWriter, *http.Request), m []Middleware, app App) http.HandlerFunc {
28-
23+
func NewMiddlewareChain(fn HandlerFunc, m []Middleware, app App) http.HandlerFunc {
2924
return func(res http.ResponseWriter, req *http.Request) {
3025
var resw ResponseWriter = NewResponseWriter(res)
3126
var f func(*App, http.ResponseWriter, *http.Request) = fn
@@ -37,7 +32,7 @@ func NewMiddlewareChain(fn func(*App, http.ResponseWriter, *http.Request), m []M
3732
if req.Body != nil {
3833
req.Body.Close()
3934
}
40-
go Logger(app, &resw, req)
35+
go logger(app, &resw, req)
4136
}
4237
}
4338

@@ -65,112 +60,3 @@ func (w *ResponseWriter) Write(b []byte) (int, error) {
6560
}
6661
return w.ResponseWriter.Write(b)
6762
}
68-
69-
type LogEntry struct {
70-
Host string `json:"host"`
71-
Method string `json:"method"`
72-
RequestURI string `json:"pathname"`
73-
Proto string `json:"proto"`
74-
Status int `json:"status"`
75-
Scheme string `json:"scheme"`
76-
UserAgent string `json:"userAgent"`
77-
Ip string `json:"ip"`
78-
Referer string `json:"referer"`
79-
Duration float64 `json:"responseTime"`
80-
Version string `json:"version"`
81-
Backend string `json:"backend"`
82-
Share string `json:"share"`
83-
License string `json:"license"`
84-
Session string `json:"session"`
85-
RequestID string `json:"requestID"`
86-
}
87-
88-
func Logger(ctx App, res http.ResponseWriter, req *http.Request) {
89-
if obj, ok := res.(*ResponseWriter); ok && req.RequestURI != "/about" {
90-
point := LogEntry{
91-
Version: APP_VERSION + "." + BUILD_DATE,
92-
License: LICENSE,
93-
Scheme: req.URL.Scheme,
94-
Host: req.Host,
95-
Method: req.Method,
96-
RequestURI: req.RequestURI,
97-
Proto: req.Proto,
98-
Status: obj.status,
99-
UserAgent: req.Header.Get("User-Agent"),
100-
Ip: req.RemoteAddr,
101-
Referer: req.Referer(),
102-
Duration: float64(time.Now().Sub(obj.start)) / (1000 * 1000),
103-
Backend: func() string {
104-
if ctx.Session["type"] == "" {
105-
return "null"
106-
}
107-
return ctx.Session["type"]
108-
}(),
109-
Share: func() string {
110-
if ctx.Share.Id == "" {
111-
return "null"
112-
}
113-
return ctx.Share.Id
114-
}(),
115-
Session: func() string {
116-
if ctx.Session["type"] == "" {
117-
return "null"
118-
}
119-
return GenerateID(&ctx)
120-
}(),
121-
RequestID: func() string {
122-
defer func() string {
123-
if r := recover(); r != nil {
124-
return "oops"
125-
}
126-
return "null"
127-
}()
128-
return res.Header().Get("X-Request-ID")
129-
}(),
130-
}
131-
if Config.Get("log.telemetry").Bool() {
132-
telemetry.Record(point)
133-
}
134-
if Config.Get("log.enable").Bool() {
135-
Log.Stdout("HTTP %3d %3s %6.1fms %s", point.Status, point.Method, point.Duration, point.RequestURI)
136-
}
137-
}
138-
}
139-
140-
type Telemetry struct {
141-
Data []LogEntry
142-
mu sync.Mutex
143-
}
144-
145-
func (this *Telemetry) Record(point LogEntry) {
146-
this.mu.Lock()
147-
this.Data = append(this.Data, point)
148-
this.mu.Unlock()
149-
}
150-
151-
func (this *Telemetry) Flush() {
152-
if len(this.Data) == 0 {
153-
return
154-
}
155-
this.mu.Lock()
156-
pts := this.Data
157-
this.Data = make([]LogEntry, 0)
158-
this.mu.Unlock()
159-
160-
body, err := json.Marshal(pts)
161-
if err != nil {
162-
return
163-
}
164-
r, err := http.NewRequest("POST", "https://downloads.filestash.app/event", bytes.NewReader(body))
165-
r.Header.Set("Connection", "Close")
166-
r.Header.Set("Content-Type", "application/json")
167-
r.Close = true
168-
if err != nil {
169-
return
170-
}
171-
resp, err := HTTP.Do(r)
172-
if err != nil {
173-
return
174-
}
175-
resp.Body.Close()
176-
}

0 commit comments

Comments
 (0)