diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py index 88270b7537b5e7..14ce0c25ad24e6 100644 --- a/Lib/test/test_urllib2.py +++ b/Lib/test/test_urllib2.py @@ -1326,6 +1326,36 @@ def request(conn, method, url, *pos, **kw): fp = urllib.request.urlopen("http://python.org/path") self.assertEqual(fp.geturl(), "http://python.org/path?query") + def test_redirect_to_cross_domain_with_sensitive_header(self): + from_url = "http://example.com/index.html" + to_url = "http://cracker.com/index.html" + h = urllib.request.HTTPRedirectHandler() + o = h.parent = MockOpener() + req = Request(from_url) + req.add_header("Authorization", "Basic foo") + req.add_header("Cookie", "bar") + req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT + h.http_error_302(req, MockFile(), 302, "", + MockHeaders({"location": to_url})) + + self.assertNotIn("Authorization", o.req.headers) + self.assertNotIn("Cookie", o.req.headers) + + def test_redirect_to_same_domain_with_sensitive_header(self): + from_url = "http://example.com/index.html" + to_url = "http://example.com/index.html" + h = urllib.request.HTTPRedirectHandler() + o = h.parent = MockOpener() + req = Request(from_url) + req.add_header("Authorization", "Basic foo") + req.add_header("Cookie", "bar") + req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT + h.http_error_302(req, MockFile(), 302, "", + MockHeaders({"location": to_url})) + + self.assertIn("Authorization", o.req.headers) + self.assertIn("Cookie", o.req.headers) + def test_redirect_encoding(self): # Some characters in the redirect target may need special handling, # but most ASCII characters should be treated as already encoded diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py index 1761e951e62466..c249c52364dd1e 100644 --- a/Lib/urllib/request.py +++ b/Lib/urllib/request.py @@ -678,11 +678,18 @@ def redirect_request(self, req, fp, code, msg, headers, newurl): CONTENT_HEADERS = ("content-length", "content-type") newheaders = {k: v for k, v in req.headers.items() if k.lower() not in CONTENT_HEADERS} - return Request(newurl, + newrequest = Request(newurl, headers=newheaders, origin_req_host=req.origin_req_host, unverifiable=True) + SENSITIVE_HEADERS = ("authorization", "cookie") + if newrequest.host != req.host: + newrequest.headers = {k: v for k, v in newrequest.headers.items() + if k.lower() not in SENSITIVE_HEADERS} + + return newrequest + # Implementation note: To avoid the server sending us into an # infinite loop, the request object needs to track what URLs we # have already seen. Do this by adding a handler-specific