// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware

import (
	
	
	
	
	
	
	
	

	
	
	
	
	
	icache 
	
	
	
)

var (
	keyCacheHit       = tag.MustNewKey("cache.hit")
	keyCacheName      = tag.MustNewKey("cache.name")
	keyCacheOperation = tag.MustNewKey("cache.operation")
	cacheResults      = stats.Int64(
		"go-discovery/cache/result_count",
		"The result of a cache request.",
		stats.UnitDimensionless,
	)
	cacheLatency = stats.Float64(
		"go-discovery/cache/result_latency",
		"Cache serving latency latency",
		stats.UnitMilliseconds,
	)
	cacheErrors = stats.Int64(
		"go-discovery/cache/errors",
		"Errors retrieving from cache.",
		stats.UnitDimensionless,
	)

	CacheResultCount = &view.View{
		Name:        "go-discovery/cache/result_count",
		Measure:     cacheResults,
		Aggregation: view.Count(),
		Description: "cache results, by cache name and whether it was a hit",
		TagKeys:     []tag.Key{keyCacheName, keyCacheHit},
	}
	CacheLatency = &view.View{
		Name:        "go-discovery/cache/result_latency",
		Measure:     cacheLatency,
		Description: "cache result latency, by cache name and whether it was a hit",
		Aggregation: ochttp.DefaultLatencyDistribution,
		TagKeys:     []tag.Key{keyCacheName, keyCacheHit},
	}
	CacheErrorCount = &view.View{
		Name:        "go-discovery/cache/errors",
		Measure:     cacheErrors,
		Aggregation: view.Count(),
		Description: "cache errors, by cache name",
		TagKeys:     []tag.Key{keyCacheName, keyCacheOperation},
	}
// TestMode, when true, makes cache writes synchronous instead of running
// them in a background goroutine. This avoids flakiness in tests.
var TestMode = false

// An Expirer computes the TTL that should be used when caching a page.
type Expirer func(r *http.Request) time.Duration

// TTL returns an Expirer that expires all pages after the given TTL,
// regardless of the request.
func TTL(ttl time.Duration) Expirer {
	return func(*http.Request) time.Duration {
		return ttl
	}
}
Cache returns a new Middleware that caches every request. The name of the cache is used only for metrics. The expirer is a func that is used to map a new request to its TTL. authHeader is the header key used by the cache to know that a request should bypass the cache. authValues is the set of values that could be set on the authHeader in order to bypass the cache.
func ( string,  *redis.Client,  Expirer,  []string) Middleware {
	return func( http.Handler) http.Handler {
		return &cache{
			name:       ,
			authValues: ,
			cache:      icache.New(),
			delegate:   ,
			expirer:    ,
		}
	}
}

Check auth header to see if request should bypass cache.
	 := .Header.Get(config.BypassCacheAuthHeader)
	for ,  := range .authValues {
		if  ==  {
			.delegate.ServeHTTP(, )
			return
		}
If the flash cookie is set, bypass the cache.
	if ,  := .Cookie(cookie.AlternativeModuleFlash);  == nil {
		.delegate.ServeHTTP(, )
		return
	}
	 := .Context()
	 := .URL.String()
	 := time.Now()
	,  := .get(, )
	recordCacheResult(, .name, , time.Since())
	if  {
		if ,  := io.Copy(, );  != nil {
			log.Errorf(, "error copying zip bytes: %v", )
		}
		return
	}
	 := newRecorder()
	.delegate.ServeHTTP(, )
	if .bufErr == nil && (.statusCode == 0 || .statusCode == http.StatusOK) {
		 := .expirer()
		if TestMode {
			.put(, , , )
		} else {
			go .put(, , , )
		}
	}
}

Set a short timeout for redis requests, so that we can quickly fall back to un-cached serving if redis is unavailable.
	,  := context.WithTimeout(, 100*time.Millisecond)
	defer ()
	,  := .cache.Get(, )
	if  != nil {
		select {
		case <-.Done():
			log.Infof(, "cache get(%q): context timed out", )
		default:
			log.Infof(, "cache get(%q): %v", , )
		}
		recordCacheError(, .name, "GET")
		return nil, false
	}
	if  == nil {
		return nil, false
	}
	,  := gzip.NewReader(bytes.NewReader())
	if  != nil {
		log.Errorf(, "cache: gzip.NewReader: %v", )
		recordCacheError(, .name, "UNZIP")
		return nil, false
	}
	return , true
}

func ( *cache) ( context.Context,  string,  *cacheRecorder,  time.Duration) {
	if  := .zipWriter.Close();  != nil {
		log.Errorf(, "cache: error closing zip for %q: %v", , )
		return
	}
	log.Infof(, "caching response of length %d for %s", .buf.Len(), )
	,  := context.WithTimeout(context.Background(), 1*time.Second)
	defer ()
	if  := .cache.Put(, , .buf.Bytes(), );  != nil {
		recordCacheError(, .name, "SET")
		log.Warningf(, "cache set %q: %v", , )
	}
}

// newRecorder returns a cacheRecorder that forwards writes to w while
// gzip-compressing a copy of the response body into an in-memory buffer.
func newRecorder(w http.ResponseWriter) *cacheRecorder {
	buf := &bytes.Buffer{}
	return &cacheRecorder{ResponseWriter: w, buf: buf, zipWriter: gzip.NewWriter(buf)}
}

// cacheRecorder is an http.ResponseWriter that collects http bytes for later
// writing to the cache. Along the way it collects any error, along with the
// resulting HTTP status code. We only cache 200 OK responses.
type cacheRecorder struct {
	http.ResponseWriter

	statusCode int           // status recorded by WriteHeader; 0 if never called
	bufErr     error         // first error encountered while buffering the body
	buf        *bytes.Buffer // gzip-compressed copy of the response body
	zipWriter  *gzip.Writer  // compresses into buf; must be closed before buf is read
}

// Write writes b to the underlying ResponseWriter and, if no buffering error
// has occurred yet, also to the compressed buffer. It returns the result of
// the underlying write; buffering problems are recorded in bufErr instead of
// being reported to the caller, so they never affect the live response.
func (r *cacheRecorder) Write(b []byte) (int, error) {
	n, err := r.ResponseWriter.Write(b)
	// Only try writing to the buffer if we haven't yet encountered an error.
	if r.bufErr != nil {
		return n, err
	}
	if err != nil {
		r.bufErr = fmt.Errorf("ResponseWriter.Write failed: %v", err)
		return n, err
	}
	zn, zerr := r.zipWriter.Write(b)
	// Prefer the zip writer's own error over the derived length mismatch, so
	// the root cause is not overwritten.
	switch {
	case zerr != nil:
		r.bufErr = zerr
	case zn != n:
		r.bufErr = fmt.Errorf("wrote %d to zip, but wanted %d", zn, n)
	}
	return n, err
}

func ( *cacheRecorder) ( int) {
Defensively take the largest status code that's written, so if any middleware thinks the response is not OK, we will capture this.
		.statusCode = 
	}
	.ResponseWriter.WriteHeader()