diff --git a/core/error.go b/core/error.go index 99438ff..34c5bc9 100644 --- a/core/error.go +++ b/core/error.go @@ -13,25 +13,27 @@ const ( ListenerAcceptErr ErrorCode = -6 InvalidIPErr ErrorCode = -7 InvalidPortErr ErrorCode = -8 - FileOpenErr ErrorCode = -9 - FileStatErr ErrorCode = -10 - FileReadErr ErrorCode = -11 - FileTypeErr ErrorCode = -12 - DirectoryReadErr ErrorCode = -13 - RestrictedPathErr ErrorCode = -14 - InvalidRequestErr ErrorCode = -15 - CGIStartErr ErrorCode = -16 - CGIExitCodeErr ErrorCode = -17 - CGIStatus400Err ErrorCode = -18 - CGIStatus401Err ErrorCode = -19 - CGIStatus403Err ErrorCode = -20 - CGIStatus404Err ErrorCode = -21 - CGIStatus408Err ErrorCode = -22 - CGIStatus410Err ErrorCode = -23 - CGIStatus500Err ErrorCode = -24 - CGIStatus501Err ErrorCode = -25 - CGIStatus503Err ErrorCode = -26 - CGIStatusUnknownErr ErrorCode = -27 + MutexUpgradeErr ErrorCode = -9 + MutexDowngradeErr ErrorCode = -10 + FileOpenErr ErrorCode = -11 + FileStatErr ErrorCode = -12 + FileReadErr ErrorCode = -13 + FileTypeErr ErrorCode = -14 + DirectoryReadErr ErrorCode = -15 + RestrictedPathErr ErrorCode = -16 + InvalidRequestErr ErrorCode = -17 + CGIStartErr ErrorCode = -18 + CGIExitCodeErr ErrorCode = -19 + CGIStatus400Err ErrorCode = -20 + CGIStatus401Err ErrorCode = -21 + CGIStatus403Err ErrorCode = -22 + CGIStatus404Err ErrorCode = -23 + CGIStatus408Err ErrorCode = -24 + CGIStatus410Err ErrorCode = -25 + CGIStatus500Err ErrorCode = -26 + CGIStatus501Err ErrorCode = -27 + CGIStatus503Err ErrorCode = -28 + CGIStatusUnknownErr ErrorCode = -29 ) // Error specifies error interface with identifiable ErrorCode @@ -62,6 +64,10 @@ func getErrorMessage(code ErrorCode) string { return invalidIPErrStr case InvalidPortErr: return invalidPortErrStr + case MutexUpgradeErr: + return mutexUpgradeErrStr + case MutexDowngradeErr: + return mutexDowngradeErrStr case FileOpenErr: return fileOpenErrStr case FileStatErr: diff --git a/core/file.go b/core/file.go index 61f6c27..f997a34 100644 --- a/core/file.go +++ b/core/file.go @@ -2,7 +2,6 @@ package core import ( "os" - "sync" "time" ) @@ -21,7 +20,7 @@ type file struct { contents FileContents lastRefresh int64 isFresh bool - sync.RWMutex + UpgradeableMutex } // newFile returns a new File based on supplied FileContents @@ -30,7 +29,7 @@ func newFile(contents FileContents) *file { contents, 0, true, - sync.RWMutex{}, + UpgradeableMutex{}, } } diff --git a/core/filesystem.go b/core/filesystem.go index 5678e59..5a237f7 100644 --- a/core/filesystem.go +++ b/core/filesystem.go @@ -5,7 +5,6 @@ import ( "io" "os" "sort" - "sync" "time" ) @@ -29,14 +28,14 @@ var ( // FileSystemObject holds onto an LRUCacheMap and manages access to it, handless freshness checking and multi-threading type FileSystemObject struct { cache *lruCacheMap - sync.RWMutex + UpgradeableMutex } // NewFileSystemObject returns a new FileSystemObject func newFileSystemObject(size int) *FileSystemObject { return &FileSystemObject{ newLRUCacheMap(size), - sync.RWMutex{}, + UpgradeableMutex{}, } } @@ -308,16 +307,20 @@ func (fs *FileSystemObject) FetchFile(client *Client, fd *os.File, stat os.FileI return err } - // Get cache write lock - fs.RUnlock() - fs.Lock() + // Try upgrade our lock, else error out + if !fs.UpgradeLock() { + return NewError(MutexUpgradeErr) + } // Put file in cache fs.cache.Put(p.Absolute(), f) - // Switch back to cache read lock, get file read lock - fs.Unlock() - fs.RLock() + // Try downgrade our lock, else error out + if !fs.DowngradeLock() { + return NewError(MutexDowngradeErr) + } + + // Get file read lock f.RLock() } else { // Get file read lock @@ -325,9 +328,10 @@ func (fs *FileSystemObject) FetchFile(client *Client, fd *os.File, stat os.FileI // Check for file freshness if !f.IsFresh() { - // Switch to file write lock - f.RUnlock() - f.Lock() + // Try upgrade file lock, else error out + if !f.UpgradeLock() { + return NewError(MutexUpgradeErr) + } // Refresh file contents err := f.CacheContents(fd, p) @@ -337,9 +341,10 @@ func (fs *FileSystemObject) FetchFile(client *Client, fd *os.File, stat os.FileI return err } - // Done! Switch back to read lock - f.Unlock() - f.RLock() + // Try downgrade file lock, else error out + if !f.DowngradeLock() { + return NewError(MutexDowngradeErr) + } } } diff --git a/core/mutex.go b/core/mutex.go new file mode 100644 index 0000000..5b42364 --- /dev/null +++ b/core/mutex.go @@ -0,0 +1,59 @@ +package core + +import ( + "sync" + "sync/atomic" + "time" +) + +type UpgradeableMutex struct { + wLast int64 + internal sync.RWMutex +} + +func (mu *UpgradeableMutex) RLock() { + mu.internal.RLock() +} + +func (mu *UpgradeableMutex) RUnlock() { + mu.internal.RUnlock() +} + +func (mu *UpgradeableMutex) Lock() { + // Get lock, set last write-lock time + mu.internal.Lock() + atomic.StoreInt64(&mu.wLast, time.Now().UnixNano()) +} + +func (mu *UpgradeableMutex) Unlock() { + mu.internal.Unlock() +} + +func (mu *UpgradeableMutex) safeSwap(swapFn func()) bool { + // Get the 'now' time + now := time.Now().UnixNano() + + // Store now time + atomic.StoreInt64(&mu.wLast, now) + + // Perform the swap + swapFn() + + // Successful swap determined by if last write-lock + // is still equal to 'now' + return atomic.LoadInt64(&mu.wLast) == now +} + +func (mu *UpgradeableMutex) UpgradeLock() bool { + return mu.safeSwap(func() { + mu.internal.RUnlock() + mu.internal.Lock() + }) +} + +func (mu *UpgradeableMutex) DowngradeLock() bool { + return mu.safeSwap(func() { + mu.internal.Unlock() + mu.internal.RLock() + }) +} diff --git a/core/string_constants.go b/core/string_constants.go index 66b438e..d84fd07 100644 --- a/core/string_constants.go +++ b/core/string_constants.go @@ -130,6 +130,8 @@ const ( listenerAcceptErrStr = "Listener accept error" invalidIPErrStr = "Invalid IP" invalidPortErrStr = "Invalid port" + mutexUpgradeErrStr = "Mutex upgrade fail" + mutexDowngradeErrStr = "Mutex downgrade fail" fileOpenErrStr = "File open error" fileStatErrStr = "File stat error" fileReadErrStr = "File read error" diff --git a/gopher/error.go b/gopher/error.go index 98160ec..4eed28d 100644 --- a/gopher/error.go +++ b/gopher/error.go @@ -42,6 +42,10 @@ func generateErrorResponse(code core.ErrorCode) ([]byte, bool) { return nil, false // not user facing case core.InvalidPortErr: return nil, false // not user facing + case core.MutexUpgradeErr: + return buildErrorLine(errorResponse500), true + case core.MutexDowngradeErr: + return buildErrorLine(errorResponse500), true case core.FileOpenErr: return buildErrorLine(errorResponse404), true case core.FileStatErr: