-- lua-lib/copas/semaphore.lua

local copas = require "copas"

-- timeout (in seconds) used by semaphore.new when no timeout is passed
local DEFAULT_TIMEOUT = 10

-- class table; it doubles as the metatable of every semaphore instance
local semaphore = {}
semaphore.__index = semaphore

-- maps a waiting coroutine to the semaphore it is queued on.
-- Keys and values are both weak, so an entry keeps neither of them alive.
local registry = setmetatable({}, { __mode = "kv" })
-- create a new semaphore
-- @param max maximum number of resources the semaphore can hold (this maximum does NOT include resources that have been given but not yet returned).
-- @param start (optional, default 0) the initial resources available
-- @param seconds (optional, default 10) default semaphore timeout in seconds, or `math.huge` to have no timeout.
-- @return the new semaphore; raises an error when `max` or `seconds` is invalid
function semaphore.new(max, start, seconds)
  -- Non-numeric input becomes -1 so it fails the range check below.
  -- `not (x >= 0)` is used instead of `x < 0` so that NaN is rejected too
  -- (every comparison with NaN is false).
  local timeout = tonumber(seconds or DEFAULT_TIMEOUT) or -1
  if not (timeout >= 0) then
    error("expected timeout (3rd argument) to be a number greater than or equal to 0, got: " .. tostring(seconds), 2)
  end
  if type(max) ~= "number" or not (max >= 1) then
    error("expected max resources (1st argument) to be a number greater than 0, got: " .. tostring(max), 2)
  end

  return setmetatable({
    count = start or 0,
    max = max,
    timeout = timeout,
    q_tip = 1,   -- position of next entry waiting
    q_tail = 1,  -- position where next one will be inserted
    queue = {},
    to_flags = setmetatable({}, { __mode = "k" }), -- timeout flags indexed by coroutine
  }, semaphore)
end
do
  -- After destruction every method lookup resolves to this function,
  -- so any further call on the semaphore returns nil, "destroyed".
  local destroyed_func = function()
    return nil, "destroyed"
  end

  local destroyed_semaphore_mt = {
    __index = function()
      return destroyed_func
    end
  }

  -- destroy a semaphore.
  -- Releases all waiting threads with `nil+"destroyed"`
  -- @return true
  function semaphore:destroy()
    -- Wake every queued waiter explicitly. Handing out resources through
    -- `give` is not enough: `give` caps the count at `max`, so when the
    -- queued requests add up to more than `max` some waiters would stay
    -- blocked until their timeout (forever with `math.huge`) and would then
    -- report "timeout" instead of "destroyed".
    for i = self.q_tip, self.q_tail - 1 do
      local item = self.queue[i]
      if item then
        self.queue[i] = nil
        -- clear any flag so the woken `take` does not report a timeout
        self.to_flags[item.co] = nil
        copas.wakeup(item.co)
        item.co = nil
      end
    end
    self.queue = {}
    self.q_tip = 1
    self.q_tail = 1

    -- woken `take` calls check this raw field to return nil, "destroyed"
    self.destroyed = true
    setmetatable(self, destroyed_semaphore_mt)
    return true
  end
end
--- Return resources to the semaphore and release waiters that can now be served.
-- Waiters are served strictly in queue order: the first waiter whose request
-- cannot be met stops the release, even if later waiters asked for less.
-- @param given (optional, default 1) number of resources to return. If the
-- resulting count exceeds the maximum it is capped at the maximum and
-- error "too many" is returned.
-- @return true, or nil+"too many" when the count had to be capped
function semaphore:give(given)
  local available = self.count + (given or 1)
  local capped = available > self.max
  if capped then
    available = self.max
  end

  while self.q_tip < self.q_tail do
    local pos = self.q_tip
    local waiter = self.queue[pos]
    if not waiter then
      -- hole left by a waiter that timed out; step over it
      self.q_tip = pos + 1
    elseif waiter.requested <= available then
      -- enough in store: hand the resources over and wake the waiter
      self.queue[pos] = nil
      self.to_flags[waiter.co] = nil
      available = available - waiter.requested
      self.q_tip = pos + 1
      copas.wakeup(waiter.co)
      waiter.co = nil
    else
      break -- the head of the queue cannot be served yet
    end
  end

  if self.q_tip == self.q_tail then
    -- queue fully drained: start over with a fresh table
    self.queue = {}
    self.q_tip, self.q_tail = 1, 1
  end

  self.count = available
  if capped then
    return nil, "too many"
  end
  return true
end
-- Timeout callback that `semaphore:take` registers with `copas.timeout`.
-- Removes the coroutine from its semaphore's wait queue, flags it as timed
-- out (so `take` returns nil, "timeout") and wakes it.
-- @param co the coroutine whose wait timed out
local function timeout_handler(co)
  local self = registry[co]
  if not self then
    return -- not waiting on any semaphore anymore
  end
  -- q_tail is the next free slot, so the last queued entry is at q_tail - 1
  for i = self.q_tip, self.q_tail - 1 do
    local item = self.queue[i]
    if item and co == item.co then
      self.queue[i] = nil
      self.to_flags[co] = true
      copas.wakeup(co)
      -- The removed waiter may have been blocking the head of the queue.
      -- Re-run the release logic without adding resources, so that waiters
      -- behind it whose requests can already be met are released, and so
      -- that a queue left holding only holes is reset (which keeps the
      -- fast path in `take` usable).
      self:give(0)
      return
    end
  end
  -- not in the queue anymore (already released): nothing to do
end
-- Requests resources from the semaphore.
-- Waits if there are not enough resources available before returning.
-- Waiters are served in queue order, so a request only succeeds immediately
-- when nobody is queued ahead of it.
-- @param requested (optional, default 1) the number of resources requested
-- @param timeout (optional, defaults to semaphore timeout) timeout in
-- seconds. If 0 it will either succeed or return immediately with error "timeout".
-- If `math.huge` it will wait forever.
-- @return true on success, or nil plus an error: "too many" (more than `max`
-- requested), "timeout", or "destroyed" (semaphore destroyed while waiting)
function semaphore:take(requested, timeout)
  requested = requested or 1

  -- Skip holes left at the head of the queue by waiters that timed out, and
  -- reset the queue when nothing real is left in it. Otherwise a lone waiter
  -- that timed out leaves `q_tail > 1` behind. That disables the fast path
  -- below even though nobody is actually waiting.
  while self.q_tip < self.q_tail and self.queue[self.q_tip] == nil do
    self.q_tip = self.q_tip + 1
  end
  if self.q_tip == self.q_tail and self.q_tail ~= 1 then
    self.queue = {}
    self.q_tip = 1
    self.q_tail = 1
  end

  if self.q_tail == 1 and self.count >= requested then
    -- nobody is waiting before us, and there is enough in store
    self.count = self.count - requested
    return true
  end

  if requested > self.max then
    return nil, "too many"
  end

  local to = timeout or self.timeout
  if to == 0 then
    return nil, "timeout"
  end

  -- get in line
  local co = coroutine.running()
  self.to_flags[co] = nil
  registry[co] = self
  -- timeout_handler dequeues us and sets to_flags[co] if the wait expires
  copas.timeout(to, timeout_handler)
  self.queue[self.q_tail] = {
    co = co,
    requested = requested,
  }
  self.q_tail = self.q_tail + 1
  copas.pauseforever() -- block until woken by give/destroy or timeout_handler

  registry[co] = nil
  if self.to_flags[co] then
    -- a timeout happened
    self.to_flags[co] = nil
    return nil, "timeout"
  end
  -- released in time; presumably a 0 delay clears the pending copas timeout
  -- (see copas.timeout docs)
  copas.timeout(0)
  if self.destroyed then
    return nil, "destroyed"
  end
  return true
end
--- Get the number of resources currently available in the semaphore.
-- @return the available resource count
function semaphore:get_count()
  local available = self.count
  return available
end
--- Get the total shortage of resources.
-- This is the sum of all resources requested by queued waiters, minus the
-- resources currently available. The result is negative when more is
-- available than is being waited for.
-- @return the shortage (number)
function semaphore:get_wait()
  local requested_total = 0
  for pos = self.q_tip, self.q_tail - 1 do
    local entry = self.queue[pos]
    -- entries can be nil: holes left by waiters that timed out
    if entry and entry.requested then
      requested_total = requested_total + entry.requested
    end
  end
  return requested_total - self.count
end
-- module export: the semaphore class table (create instances with `semaphore.new`)
return semaphore