|
| 1 | +-- |
| 2 | +-- Cross-origin Request Sharing (CORS) implementation for HAProxy Lua host |
| 3 | +-- |
| 4 | +-- CORS RFC: |
| 5 | +-- https://www.w3.org/TR/cors/ |
| 6 | +-- |
| 7 | +-- Copyright (c) 2019. Nick Ramirez <[email protected]> |
| 8 | +-- Copyright (c) 2019. HAProxy Technologies, LLC. |
| 9 | + |
| 10 | +local M={} |
| 11 | + |
| 12 | +-- Loops through array to find the given string. |
| 13 | +-- items: array of strings |
| 14 | +-- test_str: string to search for |
| 15 | +function contains(items, test_str) |
| 16 | + for _,item in pairs(items) do |
| 17 | + if item == test_str then |
| 18 | + return true |
| 19 | + end |
| 20 | + end |
| 21 | + |
| 22 | + return false |
| 23 | +end |
| 24 | + |
| 25 | +M.wildcard_origin_allowed = function(allowed_origins) |
| 26 | + if contains(allowed_origins, "*") then |
| 27 | + return "*" |
| 28 | + end |
| 29 | + return nil |
| 30 | +end |
| 31 | + |
| 32 | +M.specifies_scheme = function(s) |
| 33 | + return string.find(s, "^%a+://") ~= nil |
| 34 | +end |
| 35 | + |
| 36 | +M.specifies_generic_scheme = function(s) |
| 37 | + return string.find(s, "^//") ~= nil |
| 38 | +end |
| 39 | + |
| 40 | +M.begins_with_dot = function(s) |
| 41 | + return string.find(s, "^%.") ~= nil |
| 42 | +end |
| 43 | + |
| 44 | +M.trim = function(s) |
| 45 | + return s:gsub("%s+", "") |
| 46 | +end |
| 47 | + |
| 48 | +M.build_pattern = function(pattern) |
| 49 | + -- remove spaces |
| 50 | + pattern = M.trim(pattern) |
| 51 | + |
| 52 | + if pattern ~= nil and pattern ~= '' then |
| 53 | + -- if there is no scheme and the pattern does not begin with a dot, |
| 54 | + -- add the generic scheme to the beginning of the pattern |
| 55 | + if M.specifies_scheme(pattern) == false and M.specifies_generic_scheme(pattern) == false and M.begins_with_dot(pattern) == false then |
| 56 | + pattern = "//" .. pattern |
| 57 | + end |
| 58 | + |
| 59 | + -- escape dots and dashes in pattern |
| 60 | + pattern = pattern:gsub("([%.%-])", "%%%1") |
| 61 | + |
| 62 | + -- an asterisk for the port means allow all ports |
| 63 | + if string.find(pattern, "[:]+%*$") ~= nil then |
| 64 | + pattern = pattern:gsub("[:]+%*$", "[:]+[0-9]+") |
| 65 | + end |
| 66 | + |
| 67 | + -- append end character |
| 68 | + pattern = pattern .. "$" |
| 69 | + return pattern |
| 70 | + end |
| 71 | + |
| 72 | + return nil |
| 73 | +end |
| 74 | + |
| 75 | +-- If the given origin is found within the allowed_origins string, it is returned. Otherwise, nil is returned. |
| 76 | +-- origin: The value from the 'origin' request header |
| 77 | +-- allowed_origins: Comma-delimited list of allowed origins. (e.g. localhost,https://localhost:8080,//test.com) |
| 78 | +-- e.g. localhost : allow http(s)://localhost |
| 79 | +-- e.g. //localhost : allow http(s)://localhost |
| 80 | +-- e.g. https://mydomain.com : allow only HTTPS of mydomain.com |
| 81 | +-- e.g. http://mydomain.com : allow only HTTP of mydomain.com |
| 82 | +-- e.g. http://mydomain.com:8080 : allow only HTTP of mydomain.com from port 8080 |
| 83 | +-- e.g. //mydomain.com : allow only http(s)://mydomain.com |
| 84 | +-- e.g. .mydomain.com : allow ALL subdomains of mydomain.com from ALL source ports |
| 85 | +-- e.g. .mydomain.com:443 : allow ALL subdomains of mydomain.com from default HTTPS source port |
| 86 | +-- |
| 87 | +-- e.g. ".mydomain.com:443, //mydomain.com:443, //localhost" |
| 88 | +-- allows all subdomains and main domain of mydomain.com only for HTTPS from default HTTPS port and allows |
| 89 | +-- all HTTP and HTTPS connections from ALL source port for localhost |
| 90 | +-- |
| 91 | +M.get_allowed_origin = function(origin, allowed_origins) |
| 92 | + if origin ~= nil then |
| 93 | + -- if wildcard (*) is allowed, return it, which allows all origins |
| 94 | + wildcard_origin = M.wildcard_origin_allowed(allowed_origins) |
| 95 | + if wildcard_origin ~= nil then |
| 96 | + return wildcard_origin |
| 97 | + end |
| 98 | + |
| 99 | + for index, allowed_origin in ipairs(allowed_origins) do |
| 100 | + pattern = M.build_pattern(allowed_origin) |
| 101 | + |
| 102 | + if pattern ~= nil then |
| 103 | + if origin:match(pattern) then |
| 104 | + core.Debug("Test: " .. pattern .. ", Origin: " .. origin .. ", Match: yes") |
| 105 | + return origin |
| 106 | + else |
| 107 | + core.Debug("Test: " .. pattern .. ", Origin: " .. origin .. ", Match: no") |
| 108 | + end |
| 109 | + end |
| 110 | + end |
| 111 | + end |
| 112 | + |
| 113 | + return nil |
| 114 | +end |
| 115 | + |
| 116 | +-- Adds headers for CORS preflight request and then attaches them to the response |
| 117 | +-- after it comes back from the server. This works with versions of HAProxy prior to 2.2. |
| 118 | +-- The downside is that the OPTIONS request must be sent to the backend server first and can't |
| 119 | +-- be intercepted and returned immediately. |
| 120 | +-- txn: The current transaction object that gives access to response properties |
| 121 | +-- allowed_methods: Comma-delimited list of allowed HTTP methods. (e.g. GET,POST,PUT,DELETE) |
| 122 | +-- allowed_headers: Comma-delimited list of allowed headers. (e.g. X-Header1,X-Header2) |
| 123 | +function preflight_request_ver1(txn, allowed_methods, allowed_headers) |
| 124 | + core.Debug("CORS: preflight request received") |
| 125 | + txn.http:res_set_header("Access-Control-Allow-Methods", allowed_methods) |
| 126 | + txn.http:res_set_header("Access-Control-Allow-Headers", allowed_headers) |
| 127 | + txn.http:res_set_header("Access-Control-Max-Age", 600) |
| 128 | + core.Debug("CORS: attaching allowed methods to response") |
| 129 | +end |
| 130 | + |
| 131 | +-- Add headers for CORS preflight request and then returns a 204 response. |
| 132 | +-- The 'reply' function used here is available in HAProxy 2.2+. It allows HAProxy to return |
| 133 | +-- a reply without contacting the server. |
| 134 | +-- txn: The current transaction object that gives access to response properties |
| 135 | +-- origin: The value from the 'origin' request header |
| 136 | +-- allowed_methods: Comma-delimited list of allowed HTTP methods. (e.g. GET,POST,PUT,DELETE) |
| 137 | +-- allowed_origins: Comma-delimited list of allowed origins. (e.g. localhost,localhost:8080,test.com) |
| 138 | +-- allowed_headers: Comma-delimited list of allowed headers. (e.g. X-Header1,X-Header2) |
| 139 | +function preflight_request_ver2(txn, origin, allowed_methods, allowed_origins, allowed_headers) |
| 140 | + core.Debug("CORS: preflight request received") |
| 141 | + |
| 142 | + local reply = txn:reply() |
| 143 | + reply:set_status(204, "No Content") |
| 144 | + reply:add_header("Content-Type", "text/html") |
| 145 | + reply:add_header("Access-Control-Allow-Methods", allowed_methods) |
| 146 | + reply:add_header("Access-Control-Allow-Headers", allowed_headers) |
| 147 | + reply:add_header("Access-Control-Max-Age", 600) |
| 148 | + |
| 149 | + local allowed_origin = M.get_allowed_origin(origin, allowed_origins) |
| 150 | + |
| 151 | + if allowed_origin == nil then |
| 152 | + core.Debug("CORS: " .. origin .. " not allowed") |
| 153 | + else |
| 154 | + core.Debug("CORS: " .. origin .. " allowed") |
| 155 | + reply:add_header("Access-Control-Allow-Origin", allowed_origin) |
| 156 | + |
| 157 | + if allowed_origin ~= "*" then |
| 158 | + reply:add_header("Vary", "Accept-Encoding,Origin") |
| 159 | + end |
| 160 | + end |
| 161 | + |
| 162 | + core.Debug("CORS: Returning reply to preflight request") |
| 163 | + txn:done(reply) |
| 164 | +end |
| 165 | + |
| 166 | +-- When invoked during a request, captures the origin header if present and stores it in a private variable. |
| 167 | +-- If the request is OPTIONS and it is a supported version of HAProxy, returns a preflight request reply. |
| 168 | +-- Otherwise, the preflight request header is added to the response after it has returned from the server. |
| 169 | +-- txn: The current transaction object that gives access to response properties |
| 170 | +-- allowed_methods: Comma-delimited list of allowed HTTP methods. (e.g. GET,POST,PUT,DELETE) |
| 171 | +-- allowed_origins: Comma-delimited list of allowed origins. (e.g. localhost,localhost:8080,test.com) |
| 172 | +-- allowed_headers: Comma-delimited list of allowed headers. (e.g. X-Header1,X-Header2) |
| 173 | +function cors_request(txn, allowed_methods, allowed_origins, allowed_headers) |
| 174 | + local headers = txn.http:req_get_headers() |
| 175 | + local transaction_data = {} |
| 176 | + local origin = nil |
| 177 | + local allowed_origins = core.tokenize(allowed_origins, ",") |
| 178 | + |
| 179 | + if headers["origin"] ~= nil and headers["origin"][0] ~= nil then |
| 180 | + core.Debug("CORS: Got 'Origin' header: " .. headers["origin"][0]) |
| 181 | + origin = headers["origin"][0] |
| 182 | + end |
| 183 | + |
| 184 | + -- Bail if client did not send an Origin |
| 185 | + -- for example, it may be a regular OPTIONS request that is not a CORS preflight request |
| 186 | + if origin == nil or origin == '' then |
| 187 | + return |
| 188 | + end |
| 189 | + |
| 190 | + transaction_data["origin"] = origin |
| 191 | + transaction_data["allowed_methods"] = allowed_methods |
| 192 | + transaction_data["allowed_origins"] = allowed_origins |
| 193 | + transaction_data["allowed_headers"] = allowed_headers |
| 194 | + |
| 195 | + txn:set_priv(transaction_data) |
| 196 | + |
| 197 | + local method = txn.sf:method() |
| 198 | + transaction_data["method"] = method |
| 199 | + |
| 200 | + if method == "OPTIONS" and txn.reply ~= nil then |
| 201 | + preflight_request_ver2(txn, origin, allowed_methods, allowed_origins, allowed_headers) |
| 202 | + end |
| 203 | +end |
| 204 | + |
| 205 | +-- When invoked during a response, sets CORS headers so that the browser can read the response from permitted domains. |
| 206 | +-- txn: The current transaction object that gives access to response properties. |
| 207 | +function cors_response(txn) |
| 208 | + local transaction_data = txn:get_priv() |
| 209 | + |
| 210 | + if transaction_data == nil then |
| 211 | + return |
| 212 | + end |
| 213 | + |
| 214 | + local origin = transaction_data["origin"] |
| 215 | + local allowed_origins = transaction_data["allowed_origins"] |
| 216 | + local allowed_methods = transaction_data["allowed_methods"] |
| 217 | + local allowed_headers = transaction_data["allowed_headers"] |
| 218 | + local method = transaction_data["method"] |
| 219 | + |
| 220 | + -- Bail if client did not send an Origin |
| 221 | + if origin == nil or origin == '' then |
| 222 | + return |
| 223 | + end |
| 224 | + |
| 225 | + local allowed_origin = M.get_allowed_origin(origin, allowed_origins) |
| 226 | + |
| 227 | + if allowed_origin == nil then |
| 228 | + core.Debug("CORS: " .. origin .. " not allowed") |
| 229 | + else |
| 230 | + if method == "OPTIONS" and txn.reply == nil then |
| 231 | + preflight_request_ver1(txn, allowed_methods, allowed_headers) |
| 232 | + end |
| 233 | + |
| 234 | + core.Debug("CORS: " .. origin .. " allowed") |
| 235 | + txn.http:res_set_header("Access-Control-Allow-Origin", allowed_origin) |
| 236 | + |
| 237 | + if allowed_origin ~= "*" then |
| 238 | + txn.http:res_add_header("Vary", "Accept-Encoding,Origin") |
| 239 | + end |
| 240 | + end |
| 241 | +end |
| 242 | + |
| 243 | +-- Register the actions with HAProxy |
| 244 | +core.register_action("cors", {"http-req"}, cors_request, 3) |
| 245 | +core.register_action("cors", {"http-res"}, cors_response, 0) |
| 246 | + |
| 247 | +return M |
0 commit comments