package restful // Copyright 2013 Ernest Micklei. All rights reserved. // Use of this source code is governed by a license // that can be found in the LICENSE file. import ( "regexp" "strconv" "strings" ) // CrossOriginResourceSharing is used to create a Container Filter that implements CORS. // Cross-origin resource sharing (CORS) is a mechanism that allows JavaScript on a web page // to make XMLHttpRequests to another domain, not the domain the JavaScript originated from. // // http://en.wikipedia.org/wiki/Cross-origin_resource_sharing // http://enable-cors.org/server.html // http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request type CrossOriginResourceSharing struct { ExposeHeaders []string // list of Header names AllowedHeaders []string // list of Header names AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed. AllowedMethods []string MaxAge int // number of seconds before requiring new Options request CookiesAllowed bool Container *Container allowedOriginPatterns []*regexp.Regexp // internal field for origin regexp check. } // Filter is a filter function that implements the CORS flow as documented on http://enable-cors.org/server.html // and http://www.html5rocks.com/static/images/cors_server_flowchart.png func (c CrossOriginResourceSharing) Filter(req *Request, resp *Response, chain *FilterChain) { origin := req.Request.Header.Get(HEADER_Origin) if len(origin) == 0 { if trace { traceLogger.Print("no Http header Origin set") } chain.ProcessFilter(req, resp) return } if !c.isOriginAllowed(origin) { // check whether this origin is allowed if trace { traceLogger.Printf("HTTP Origin:%s is not part of %v, neither matches any part of %v", origin, c.AllowedDomains, c.allowedOriginPatterns) } chain.ProcessFilter(req, resp) return } if req.Request.Method != "OPTIONS" { c.doActualRequest(req, resp) chain.ProcessFilter(req, resp) return } if acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod); acrm != "" { c.doPreflightRequest(req, resp) } else { c.doActualRequest(req, resp) chain.ProcessFilter(req, resp) return } } func (c CrossOriginResourceSharing) doActualRequest(req *Request, resp *Response) { c.setOptionsHeaders(req, resp) // continue processing the response } func (c *CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response) { if len(c.AllowedMethods) == 0 { if c.Container == nil { c.AllowedMethods = DefaultContainer.computeAllowedMethods(req) } else { c.AllowedMethods = c.Container.computeAllowedMethods(req) } } acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod) if !c.isValidAccessControlRequestMethod(acrm, c.AllowedMethods) { if trace { traceLogger.Printf("Http header %s:%s is not in %v", HEADER_AccessControlRequestMethod, acrm, c.AllowedMethods) } return } acrhs := req.Request.Header.Get(HEADER_AccessControlRequestHeaders) if len(acrhs) > 0 { for _, each := range strings.Split(acrhs, ",") { if !c.isValidAccessControlRequestHeader(strings.Trim(each, " ")) { if trace { traceLogger.Printf("Http header %s:%s is not in %v", HEADER_AccessControlRequestHeaders, acrhs, c.AllowedHeaders) } return } } } resp.AddHeader(HEADER_AccessControlAllowMethods, strings.Join(c.AllowedMethods, ",")) resp.AddHeader(HEADER_AccessControlAllowHeaders, acrhs) c.setOptionsHeaders(req, resp) // return http 200 response, no body } func (c CrossOriginResourceSharing) setOptionsHeaders(req *Request, resp *Response) { c.checkAndSetExposeHeaders(resp) c.setAllowOriginHeader(req, resp) c.checkAndSetAllowCredentials(resp) if c.MaxAge > 0 { resp.AddHeader(HEADER_AccessControlMaxAge, strconv.Itoa(c.MaxAge)) } } func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool { if len(origin) == 0 { return false } if len(c.AllowedDomains) == 0 { return true } allowed := false for _, domain := range c.AllowedDomains { if domain == origin { allowed = true break } } if !allowed { if len(c.allowedOriginPatterns) == 0 { // compile allowed domains to allowed origin patterns allowedOriginRegexps, err := compileRegexps(c.AllowedDomains) if err != nil { return false } c.allowedOriginPatterns = allowedOriginRegexps } for _, pattern := range c.allowedOriginPatterns { if allowed = pattern.MatchString(origin); allowed { break } } } return allowed } func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) { origin := req.Request.Header.Get(HEADER_Origin) if c.isOriginAllowed(origin) { resp.AddHeader(HEADER_AccessControlAllowOrigin, origin) } } func (c CrossOriginResourceSharing) checkAndSetExposeHeaders(resp *Response) { if len(c.ExposeHeaders) > 0 { resp.AddHeader(HEADER_AccessControlExposeHeaders, strings.Join(c.ExposeHeaders, ",")) } } func (c CrossOriginResourceSharing) checkAndSetAllowCredentials(resp *Response) { if c.CookiesAllowed { resp.AddHeader(HEADER_AccessControlAllowCredentials, "true") } } func (c CrossOriginResourceSharing) isValidAccessControlRequestMethod(method string, allowedMethods []string) bool { for _, each := range allowedMethods { if each == method { return true } } return false } func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header string) bool { for _, each := range c.AllowedHeaders { if strings.ToLower(each) == strings.ToLower(header) { return true } } return false } // Take a list of strings and compile them into a list of regular expressions. func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) { regexps := []*regexp.Regexp{} for _, regexpStr := range regexpStrings { r, err := regexp.Compile(regexpStr) if err != nil { return regexps, err } regexps = append(regexps, r) } return regexps, nil }