-------------------------------------------------------------------------------
-- Copas - Coroutine Oriented Portable Asynchronous Services
--
-- A dispatcher based on coroutines that can be used by TCP/IP servers.
-- Uses LuaSocket as the interface with the TCP/IP stack.
--
-- Authors: Andre Carregal, Javier Guerra, and Fabio Mascarenhas
-- Contributors: Diego Nehab, Mike Pall, David Burgess, Leonardo Godinho,
--               Thomas Harning Jr., and Gary NG
--
-- Copyright 2005-2016 - Kepler Project (www.keplerproject.org)
--
-- $Id: copas.lua,v 1.37 2009/04/07 22:09:52 carregal Exp $
-------------------------------------------------------------------------------

-- Guard against a load-order problem: on Lua 5.1 socket.http must be loaded
-- after copas. Presumably this is because socket.http picks up
-- socket.protect / socket.newtry when it is loaded, and copas replaces those
-- with coroutine-safe versions further down in this file.
if package.loaded["socket.http"] and (_VERSION=="Lua 5.1") then     -- obsolete: only for Lua 5.1 compatibility
  error("you must require copas before require'ing socket.http")
end

local socket = require "socket"
local gettime = socket.gettime  -- LuaSocket wall-clock time, in seconds
local ssl -- only loaded upon demand

-- default timeout (seconds) used by newtimer() when no timeout is given
local WATCH_DOG_TIMEOUT = 120
-- default receive size for UDP sockets
local UDP_DATAGRAM_MAX = 8192  -- TODO: dynamically get this value from LuaSocket

-- On plain Lua 5.1 (not LuaJIT) a coroutine cannot yield across pcall, so
-- coxpcall's implementation is swapped in. Every `pcall` below (socket.protect,
-- socket.newtry, _doTick) refers to this local.
local pcall = pcall
if _VERSION=="Lua 5.1" and not jit then     -- obsolete: only for Lua 5.1 compatibility
  pcall = require("coxpcall").pcall
end
  
-- Redefines LuaSocket functions with coroutine safe versions
-- (this allows the use of socket.http from within copas)

--- Converts the results of a pcall into LuaSocket's "protect" convention.
-- On success every value after the status flag is returned unchanged.
-- On failure, an error raised as a table (as the functions built by
-- socket.newtry do) becomes `nil, err[1]`; any other error is re-raised.
local function statusHandler(status, ...)
  if not status then
    local failure = ...
    if type(failure) ~= "table" then
      error(failure)
    end
    return nil, failure[1]
  end
  return ...
end

--- Coroutine-safe replacement for LuaSocket's socket.protect.
-- The returned function calls `func` through the file-local `pcall` (the
-- coxpcall one on plain Lua 5.1) and converts the outcome with statusHandler.
-- Errors raised as tables become `nil, msg`; other errors propagate.
function socket.protect(func)
  local function protected(...)
    return statusHandler(pcall(func, ...))
  end
  return protected
end

--- Coroutine-safe replacement for LuaSocket's socket.newtry.
-- The returned "try" function passes all its arguments through when the
-- first one is truthy. Otherwise it calls `finalizer` (protected) with the
-- remaining arguments and raises `{ second_argument }` at level 0, which
-- statusHandler turns back into `nil, err`.
function socket.newtry(finalizer)
  return function(ok, ...)
    if ok then
      return ok, ...
    end
    pcall(finalizer, ...)
    error({ (...) }, 0)
  end
end

--- The module table returned at the end of this file.
-- Meta information is public even if beginning with an "_".
local copas = {
  _COPYRIGHT   = "Copyright (C) 2005-2017 Kepler Project",
  _DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services",
  _VERSION     = "Copas 2.0.2",

  -- Close the (TCP) socket associated with the current connection after the
  -- handler finishes
  autoclose = true,

  -- indicator for the loop running; set and cleared by copas.loop
  running = false,
}

-------------------------------------------------------------------------------
-- Simple set implementation based on LuaSocket's tinyirc.lua example
-- adds a FIFO queue for each value in the set
-------------------------------------------------------------------------------
--- Creates an empty set with a per-key FIFO queue.
-- Members live in the array part of the returned table (so `#set`, `next`
-- and socket.select see only members). The methods sit in the metatable's
-- __index so they do not show up as members.
-- @return the new set
local function newset()
  local position = {} -- member -> its index in the array part
  local queues = {}   -- key -> FIFO list of items pushed for that key
  local methods = {}

  -- Appends `value` unless it is already a member.
  function methods.insert(self, value)
    if position[value] then return end
    local n = #self + 1
    self[n] = value
    position[value] = n
  end

  -- Removes `value` in O(1) by moving the last member into its slot.
  function methods.remove(self, value)
    local idx = position[value]
    if not idx then return end
    position[value] = nil
    local last_idx = #self
    local last = self[last_idx]
    self[last_idx] = nil
    if last ~= value then
      self[idx] = last
      position[last] = idx
    end
  end

  -- Appends `item` to the FIFO queue associated with `key`.
  function methods.push(self, key, item)
    local queue = queues[key]
    if queue then
      queue[#queue + 1] = item
    else
      queues[key] = { item }
    end
  end

  -- Removes and returns the oldest item queued for `key` (nothing if none);
  -- the queue itself is dropped once it becomes empty.
  function methods.pop(self, key)
    local queue = queues[key]
    if queue == nil then return end
    local item = table.remove(queue, 1)
    if queue[1] == nil then
      queues[key] = nil
    end
    return item
  end

  return setmetatable({}, { __index = methods })
end

local fnil = function()end

--- Queue of sleeping coroutines.
-- `times` and `cos` are parallel arrays kept sorted by ascending wake-up
-- time; `lethargy` holds coroutines that sleep until copas.wakeup.
-- `insert`/`remove` are no-ops so the scheduler can treat this object like a
-- socket set (see `new_q:insert(res)` in _doTick).
local _sleeping = {
    times = {},    -- list with wake-up times, ascending
    cos = {},      -- list with coroutines, index matches the 'times' list
    lethargy = {}, -- set of coroutines sleeping without a wakeup time

    insert = fnil,
    remove = fnil,

    --- Schedules `co` to wake up `sleeptime` seconds from now.
    -- A negative `sleeptime` parks `co` until copas.wakeup instead.
    -- Coroutines with equal wake-up times are woken in insertion order.
    push = function(self, sleeptime, co)
        if not co then return end
        if sleeptime < 0 then
            --sleep until explicit wakeup through copas.wakeup
            self.lethargy[co] = true
            return
        end
        local wake = gettime() + sleeptime
        local t, c = self.times, self.cos
        -- binary search for the first slot whose time is strictly later than
        -- `wake` (upper bound), so equal times keep FIFO order
        local lo, hi = 1, #t + 1
        while lo < hi do
            local mid = math.floor((lo + hi) / 2)
            if t[mid] <= wake then
                lo = mid + 1
            else
                hi = mid
            end
        end
        table.insert(t, lo, wake)
        table.insert(c, lo, co)
    end,

    --- Returns the delay (>= 0 seconds) until the earliest timed sleeper is
    -- due, or nil when no coroutine sleeps with a wake-up time.
    getnext = function(self)
        local first = self.times[1]
        if not first then return nil end
        return math.max(first - gettime(), 0)
    end,

    --- Removes and returns the earliest sleeper if it is due at `time`;
    -- returns nothing otherwise.
    pop = function(self, time)
        local t, c = self.times, self.cos
        if #t == 0 or time < t[1] then return end
        local co = c[1]
        table.remove(t, 1)
        table.remove(c, 1)
        return co
    end,

    --- Reschedules `co` to wake up immediately, whether it sleeps without a
    -- wake-up time or with a pending one. Unknown coroutines are ignored.
    wakeup = function(self, co)
        if self.lethargy[co] then
            self:push(0, co)
            self.lethargy[co] = nil
            return
        end
        local cos = self.cos
        for i = 1, #cos do
            if cos[i] == co then
                table.remove(cos, i)
                table.remove(self.times, i)
                self:push(0, co)
                return
            end
        end
    end
} --_sleeping

--- Creates a watchdog timer that expires `timeout` seconds from now.
-- @param timeout seconds until expiry; defaults to WATCH_DOG_TIMEOUT
-- @return table with `timeout_time` (absolute deadline) and an `expired`
-- method that is true once gettime() reaches the deadline
local function newtimer(timeout)
  local deadline = gettime() + (timeout or WATCH_DOG_TIMEOUT)
  local timer = { timeout_time = deadline }
  function timer.expired(self)
    return self.timeout_time <= gettime()
  end
  return timer
end

-- Socket sets whose members are passed to socket.select by _select.
local _reading = newset() -- sockets currently being read
local _writing = newset() -- sockets currently being written
-- servers being handled; also used as a map server -> handler
local _servers = newset()

-- socket -> watchdog timer of the coroutine waiting on it, so _select can
-- retry sockets whose wait has timed out even if select() never reports them
local _reading_log = {}
local _writing_log = {}

-- set of errors indicating a timeout ("not ready yet") rather than a failure
local _isTimeout = {
  timeout = true,   -- default LuaSocket timeout
  wantread = true,  -- LuaSec specific timeout
  wantwrite = true, -- LuaSec specific timeout
}

-------------------------------------------------------------------------------
-- Coroutine based socket I/O functions.
-------------------------------------------------------------------------------

--- Tells whether a socket is TCP, judged by its string representation.
-- Anything whose tostring() does not start with "udp" is treated as TCP.
local function isTCP(skt)
  local prefix = tostring(skt):sub(1, 3)
  return prefix ~= "udp"
end

-- reads a pattern from a client and yields to the reading set on timeouts
-- UDP: a UDP socket expects a second argument to be a number, so it MUST
-- be provided as the 'pattern' below defaults to a string. Will throw a
-- 'bad argument' error if omitted.
-- @param client raw socket to read from (must be non-blocking)
-- @param pattern LuaSocket receive pattern, defaults to "*l"
-- @param part optional prefix / partial result carried between retries
-- @param timeout optional seconds before giving up (default WATCH_DOG_TIMEOUT)
-- @return whatever client:receive returned on its last call: data on
-- success, or nil, err, partial on a non-timeout error or timer expiry
function copas.receive(client, pattern, part, timeout)
  local s, err
  pattern = pattern or "*l"
  -- the log holding this socket's watchdog timer; LuaSec may ask us to wait
  -- for writability ("wantwrite") while reading, so this can switch
  local current_log = _reading_log
  local timer = newtimer(timeout)
  repeat
    s, err, part = client:receive(pattern, part)
    if s or (not _isTimeout[err]) or timer:expired() then
      current_log[client] = nil
      return s, err, part
    end
    if err == "wantwrite" then
      current_log = _writing_log
      current_log[client] = timer
      coroutine.yield(client, _writing)
    else
      current_log = _reading_log
      current_log[client] = timer
      coroutine.yield(client, _reading)
    end
  until false
end

-- receives data from a client over UDP. Not available for TCP.
-- (this is a copy of receive() method, adapted for receivefrom() use)
-- @param client raw UDP socket
-- @param size max datagram size, defaults to UDP_DATAGRAM_MAX
-- @param timeout optional seconds before giving up (default WATCH_DOG_TIMEOUT)
-- @return data, ip, port on success; otherwise the results of the last
-- client:receivefrom call (nil, err)
function copas.receivefrom(client, size, timeout)
  local timer = newtimer(timeout)
  size = size or UDP_DATAGRAM_MAX
  while true do
    -- upon success the second value holds the ip address
    local data, ip_or_err, port = client:receivefrom(size)
    if data or ip_or_err ~= "timeout" or timer:expired() then
      _reading_log[client] = nil
      return data, ip_or_err, port
    end
    _reading_log[client] = timer
    coroutine.yield(client, _reading)
  end
end

-- same as above but with special treatment when reading chunks,
-- unblocks on any data received.
-- When `pattern` is a number, this returns as soon as a non-empty partial
-- result is available instead of waiting for the full count. It always uses
-- the default watchdog timeout (newtimer() is called without an argument).
-- @param client raw socket to read from (must be non-blocking)
-- @param pattern LuaSocket receive pattern, defaults to "*l"
-- @param part optional prefix / partial result carried between retries
-- @return the results of the last client:receive call (s, err, part)
function copas.receivePartial(client, pattern, part)
  local s, err
  pattern = pattern or "*l"
  -- log holding this socket's watchdog timer; switches on LuaSec "wantwrite"
  local current_log = _reading_log
  local timer = newtimer()
  repeat
    s, err, part = client:receive(pattern, part)
    if s or ((type(pattern)=="number") and part~="" and part ~=nil ) or (not _isTimeout[err]) or timer:expired() then
      current_log[client] = nil
      return s, err, part
    end
    if err == "wantwrite" then
      current_log = _writing_log
      current_log[client] = timer
      coroutine.yield(client, _writing)
    else
      current_log = _reading_log
      current_log[client] = timer
      coroutine.yield(client, _reading)
    end
  until false
end

-- sends data to a client. The operation is buffered and
-- yields to the writing set on timeouts
-- Note: from and to parameters will be ignored by/for UDP sockets
-- @param client raw socket to write to (must be non-blocking)
-- @param data string to send
-- @param from first byte index to send, defaults to 1
-- @param to optional last byte index to send
-- @param timeout optional seconds before giving up (default WATCH_DOG_TIMEOUT)
-- @return the results of the last client:send call (s, err, lastIndex)
function copas.send(client, data, from, to, timeout)
  local s, err
  from = from or 1
  -- index of the last byte already sent; each retry resumes after it
  local lastIndex = from - 1
  -- log holding this socket's watchdog timer; switches on LuaSec "wantread"
  local current_log = _writing_log
  local timer = newtimer(timeout)
  repeat
    s, err, lastIndex = client:send(data, lastIndex + 1, to)
    -- adds extra coroutine swap
    -- guarantees that high throughput doesn't take other threads to starvation
    -- (on roughly 10% of iterations, even after a successful send)
    if (math.random(100) > 90) then
      current_log[client] = timer   -- TODO: how to handle this?? 
      if current_log == _writing_log then
        coroutine.yield(client, _writing)
      else
        coroutine.yield(client, _reading)
      end
    end
    if s or (not _isTimeout[err]) or timer:expired() then
      current_log[client] = nil
      return s, err,lastIndex
    end
    if err == "wantread" then
      current_log = _reading_log
      current_log[client] = timer
      coroutine.yield(client, _reading)
    else
      current_log = _writing_log
      current_log[client] = timer
      coroutine.yield(client, _writing)
    end
  until false
end

-- sends data to a client over UDP. Not available for TCP.
-- (this is a copy of send() method, adapted for sendto() use)
-- @param client raw UDP socket
-- @param data datagram payload
-- @param ip destination address
-- @param port destination port
-- @param timeout optional seconds before giving up (default WATCH_DOG_TIMEOUT)
-- @return the results of the last client:sendto call (ok, err)
function copas.sendto(client, data, ip, port, timeout)
  local timer = newtimer(timeout)
  while true do
    local ok, err = client:sendto(data, ip, port)
    -- adds extra coroutine swap
    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      _writing_log[client] = timer
      coroutine.yield(client, _writing)
    end
    local finished = ok or err ~= "timeout" or timer:expired()
    if finished then
      _writing_log[client] = nil
      return ok, err
    end
    _writing_log[client] = timer
    coroutine.yield(client, _writing)
  end
end

-- waits until connection is completed
--- Connects a TCP socket without blocking other coroutines.
-- Sets `skt` non-blocking and retries connect() until it succeeds, fails
-- with a real error, or `timeout` seconds (default WATCH_DOG_TIMEOUT) pass,
-- yielding on the writing set between attempts.
-- @param skt raw LuaSocket TCP socket
-- @param host remote host
-- @param port remote port
-- @param timeout optional timeout in seconds
-- @return connect()'s truthy result (1 after an async completion) on
-- success; a falsy value plus an error message (or "timeout") on failure
function copas.connect(skt, host, port, timeout)
  skt:settimeout(0)
  local ret, err
  local tried_more_than_once = false
  local timer = newtimer(timeout)
  repeat
    ret, err = skt:connect(host, port)
    if (not ret) and timer:expired() then
      -- clear the watchdog entry left by a previous attempt, as every other
      -- return path does, so the log does not keep referencing this socket
      _writing_log[skt] = nil
      return ret, "timeout"
    end
    -- non-blocking connect on Windows results in error "Operation already
    -- in progress" to indicate that it is completing the request async. So essentially
    -- it is the same as "timeout"
    -- "Invalid argument" explanation: https://github.com/diegonehab/luasocket/pull/190
    if ret or (err ~= "timeout" and err ~= "Operation already in progress" and err ~= "Invalid argument") then
      -- Once the async connect completes, Windows returns the error "already connected"
      -- to indicate it is done, so that error should be ignored. Except when it is the
      -- first call to connect, then it was already connected to something else and the
      -- error should be returned
      if (not ret) and (err == "already connected" and tried_more_than_once) then
        ret = 1
        err = nil
      end
      _writing_log[skt] = nil
      return ret, err
    end
    tried_more_than_once = true
    _writing_log[skt] = timer
    coroutine.yield(skt, _writing)
  until false
end

---
-- Performs an (async) ssl handshake on a connected TCP client socket.
-- NOTE: replace all previous socket references, with the returned new ssl wrapped socket
-- Throws error and does not return nil+error, as that might silently fail
-- in code like this;
--   copas.addserver(s1, function(skt)
--       skt = copas.wrap(skt, sparams)
--       skt:dohandshake()   --> without explicit error checking, this fails silently and
--       skt:send(body)      --> continues unencrypted
-- LuaSec ("ssl") is required lazily on first use.
-- @param skt Regular LuaSocket CLIENT socket object
-- @param sslt Table with ssl parameters
-- @return wrapped ssl socket, or throws an error
function copas.dohandshake(skt, sslt)
  ssl = ssl or require("ssl")
  local nskt, err = ssl.wrap(skt, sslt)
  if not nskt then return error(err) end
  nskt:settimeout(0)
  -- LuaSec reports which readiness it needs; wait on the matching set
  local wait_queue = { wantwrite = _writing, wantread = _reading }
  while true do
    local done, hs_err = nskt:dohandshake()
    if done then
      return nskt
    end
    local queue = wait_queue[hs_err]
    if not queue then
      error(hs_err)
    end
    coroutine.yield(nskt, queue)
  end
end

--- Flushes a client write buffer (deprecated).
-- Intentionally does nothing.
-- @param client ignored
copas.flush = function(client)
end

-- wraps a TCP socket to use Copas methods (send, receive, flush and settimeout)
-- A wrapped socket is a table { socket = <raw socket>, timeout = <seconds|nil>,
-- ssl_params = <table|nil> } with this table as its metatable.
local _skt_mt_tcp
do
  local methods = {}

  -- Sends through copas.send using the wrapper's timeout.
  function methods:send(data, from, to)
    return copas.send(self.socket, data, from, to, self.timeout)
  end

  -- Receives through copas. A timeout of exactly 0 selects
  -- copas.receivePartial; anything else (including nil) uses copas.receive.
  function methods:receive(pattern, prefix)
    if self.timeout == 0 then
      return copas.receivePartial(self.socket, pattern, prefix)
    end
    return copas.receive(self.socket, pattern, prefix, self.timeout)
  end

  function methods:flush()
    return copas.flush(self.socket)
  end

  -- Stores the timeout on the wrapper only; the wrapped socket's own
  -- timeout is left untouched.
  function methods:settimeout(time)
    self.timeout = time
    return true
  end

  -- TODO: socket.connect is a shortcut, and must be provided with an alternative
  -- if ssl parameters are available, it will also include a handshake
  function methods:connect(host, port)
    local res, err = copas.connect(self.socket, host, port, self.timeout)
    if res and self.ssl_params then
      res, err = self:dohandshake()
    end
    return res, err
  end

  -- Performs the ssl handshake (remembering `sslt` when given) and swaps the
  -- internal socket for the ssl wrapped one; returns the wrapper itself.
  function methods:dohandshake(sslt)
    self.ssl_params = sslt or self.ssl_params
    local nskt, err = copas.dohandshake(self.socket, self.ssl_params)
    if not nskt then return nskt, err end
    self.socket = nskt  -- replace internal socket with the newly wrapped ssl one
    return self
  end

  -- Plain pass-throughs to the wrapped socket.
  -- TODO: socket.bind is a shortcut, and must be provided with an alternative
  -- TODO: getsockname / getpeername: is this DNS related? hence blocking?
  local passthrough = {
    "close", "bind", "getsockname", "getstats", "setstats",
    "listen", "accept", "setoption", "getpeername", "shutdown",
  }
  for _, name in ipairs(passthrough) do
    methods[name] = function(self, ...)
      local skt = self.socket
      return skt[name](skt, ...)
    end
  end

  _skt_mt_tcp = {
    __tostring = function(self)
      return tostring(self.socket) .. " (copas wrapped)"
    end,
    __index = methods,
  }
end

-- wraps a UDP socket, copy of TCP one adapted for UDP.
-- Starts from every metamethod and method of the TCP wrapper, then overrides
-- the ones whose UDP semantics differ.
local _skt_mt_udp = {__index = { }}
for k,v in pairs(_skt_mt_tcp) do _skt_mt_udp[k] = _skt_mt_udp[k] or v end
for k,v in pairs(_skt_mt_tcp.__index) do _skt_mt_udp.__index[k] = v end

_skt_mt_udp.__index.sendto =      function (self, data, ip, port)
                                    -- UDP sending is non-blocking, but we provide starvation prevention, so replace anyway
                                    return copas.sendto(self.socket, data, ip, port, self.timeout)
                                  end

-- `size` defaults to UDP_DATAGRAM_MAX, as a UDP receive needs a numeric size
_skt_mt_udp.__index.receive =     function (self, size)
                                    return copas.receive(self.socket, (size or UDP_DATAGRAM_MAX), nil, self.timeout)
                                  end

_skt_mt_udp.__index.receivefrom = function (self, size)
                                    return copas.receivefrom(self.socket, (size or UDP_DATAGRAM_MAX), self.timeout)
                                  end

                                  -- TODO: is this DNS related? hence blocking?
-- forwards to the socket's setpeername (previously called getpeername by
-- mistake, so the peer was never set)
_skt_mt_udp.__index.setpeername = function(self, ...) return self.socket:setpeername(...) end

_skt_mt_udp.__index.setsockname = function(self, ...) return self.socket:setsockname(...) end

                                    -- do not close client, as it is also the server for udp.
_skt_mt_udp.__index.close       = function(self, ...) return true end


---
-- Wraps a LuaSocket socket object in an async Copas based socket object.
-- Already wrapped objects are returned unchanged; otherwise the raw socket
-- is switched to non-blocking mode (timeout 0) before wrapping.
-- @param skt The socket to wrap
-- @param sslt (optional) Table with ssl parameters, use an empty table to use ssl with defaults
-- (only stored for TCP sockets)
-- @return wrapped socket object
function copas.wrap(skt, sslt)
  local mt = getmetatable(skt)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    return skt -- already wrapped
  end
  skt:settimeout(0)
  if isTCP(skt) then
    return setmetatable({ socket = skt, ssl_params = sslt }, _skt_mt_tcp)
  end
  return setmetatable({ socket = skt }, _skt_mt_udp)
end

--- Wraps a handler in a function that deals with wrapping the socket and doing the
-- optional ssl handshake.
-- @param handler function(wrapped_socket, ...) to run for each connection
-- @param sslparams optional ssl parameters; when given, a handshake is done
-- (and errors are raised by copas.dohandshake) before `handler` runs
-- @return function(skt, ...) suitable for copas.addserver
function copas.handler(handler, sslparams)
  local function wrapped_handler(skt, ...)
    local wrapped = copas.wrap(skt)
    if sslparams then
      wrapped:dohandshake(sslparams)
    end
    return handler(wrapped, ...)
  end
  return wrapped_handler
end


--------------------------------------------------
-- Error handling
--------------------------------------------------

local _errhandlers = {}   -- error handler per coroutine

--- Registers `err` as the error handler for the currently running coroutine.
-- The handler is called by the scheduler as handler(msg, co, skt) when the
-- coroutine dies with an error. Does nothing if coroutine.running() returns
-- nil (the main thread on Lua 5.1).
function copas.setErrorHandler(err)
  local co = coroutine.running()
  if not co then return end
  _errhandlers[co] = err
end

--- Default error handler: prints the error message, the failing coroutine
-- and its socket (nil for plain threads).
local _deferror = function(message, thread, socket_obj)
  print(message, thread, socket_obj)
end

-------------------------------------------------------------------------------
-- Thread handling
-------------------------------------------------------------------------------

--- Resumes coroutine `co` with `skt` and any extra arguments.
-- If the coroutine yields `(res, new_q)`, it is queued on `new_q` under key
-- `res`. Otherwise it has finished, failed, or yielded without a queue: any
-- error goes to its handler (or _deferror), a TCP `skt` is closed when
-- copas.autoclose is set, and its error handler registration is dropped.
-- Does nothing when `co` is nil.
local function _doTick(co, skt, ...)
  if not co then return end

  local ok, res, new_q = coroutine.resume(co, skt, ...)
  local rescheduled = ok and res and new_q
  if rescheduled then
    new_q:insert(res)
    new_q:push(res, co)
    return
  end

  if not ok then
    local handler = _errhandlers[co] or _deferror
    pcall(handler, res, co, skt)
  end
  if skt and copas.autoclose and isTCP(skt) then
    skt:close() -- do not auto-close UDP sockets, as the handler socket is also the server socket
  end
  _errhandlers[co] = nil
end

--- Accepts a pending connection on server socket `input` and starts
-- `handler` for it in a new coroutine (the client is made non-blocking first).
-- @return the accepted client, or whatever falsy value accept() returned
local function _accept(input, handler)
  local client = input:accept()
  if not client then
    return client
  end
  client:settimeout(0)
  _doTick(coroutine.create(handler), client)
  return client
end

-- handle threads on a queue: resume the oldest coroutine waiting on `skt`
-- in the reading set
local function _tickRead(skt)
  local co = _reading:pop(skt)
  _doTick(co, skt)
end

-- same as _tickRead, for the writing set
local function _tickWrite(skt)
  local co = _writing:pop(skt)
  _doTick(co, skt)
end

-------------------------------------------------------------------------------
-- Adds a server/handler pair to Copas dispatcher
-------------------------------------------------------------------------------
--- Registers a TCP server: when it becomes readable, a connection is
-- accepted and served by `handler` (see _readable_t.tick).
local function addTCPserver(server, handler, timeout)
  server:settimeout(timeout or 0)
  _reading:insert(server)
  _servers[server] = handler
end

--- Registers a UDP server: `handler` starts immediately in its own
-- coroutine with the server socket as its argument.
local function addUDPserver(server, handler, timeout)
    server:settimeout(timeout or 0)
    _reading:insert(server)
    _doTick(coroutine.create(handler), server)
end

--- Adds a server socket and its handler to the dispatcher, dispatching on
-- whether the socket is TCP or UDP.
-- @param server raw server socket
-- @param handler function run for each client (TCP) or once (UDP)
-- @param timeout optional timeout applied to the server socket (default 0)
function copas.addserver(server, handler, timeout)
    local add = isTCP(server) and addTCPserver or addUDPserver
    add(server, handler, timeout)
end

--- Removes a server (raw or copas-wrapped) from the dispatcher.
-- @param server the server socket, possibly wrapped by copas.wrap
-- @param keep_open when truthy, the socket is not closed
-- @return true if kept open, otherwise the result of server:close()
function copas.removeserver(server, keep_open)
  local raw = server
  local mt = getmetatable(server)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    raw = server.socket
  end
  _servers[raw] = nil
  _reading:remove(raw)
  if keep_open then return true end
  return server:close()
end

-------------------------------------------------------------------------------
-- Adds a new coroutine thread to Copas dispatcher
-------------------------------------------------------------------------------
--- Starts `handler(...)` as a new coroutine managed by the dispatcher.
-- @return the created coroutine
function copas.addthread(handler, ...)
  -- the scheduler always resumes with a socket as first argument (nil for
  -- a task/thread); drop it so `handler` only sees its own arguments
  local function skip_socket(_, ...)
    return handler(...)
  end
  local thread = coroutine.create(skip_socket)
  _doTick(thread, nil, ...)
  return thread
end

-------------------------------------------------------------------------------
-- tasks registering
-------------------------------------------------------------------------------

local _tasks = {}  -- set of registered tasks (task -> true)

-- Registers `tsk` with `default_tick` as its def_tick, so the task can
-- fall back to the default _tick() behaviour.
local function _addtask(tsk, default_tick)
  tsk.def_tick = default_tick
  _tasks[tsk] = true
end

local function addtaskRead(tsk)
  _addtask(tsk, _tickRead)
end

local function addtaskWrite(tsk)
  _addtask(tsk, _tickWrite)
end

--- Generic-for iterator over the registered tasks (order unspecified).
local function tasks()
  return next, _tasks
end

-------------------------------------------------------------------------------
-- main tasks: manage readable and writable socket sets
-------------------------------------------------------------------------------
-- a task to check ready to read events
local _readable_t = {}

--- Iterates over the sockets select() reported readable (array part of
-- self._evs, read on every step).
function _readable_t:events()
  local i = 0
  return function()
    i = i + 1
    return self._evs[i]
  end
end

--- Handles one readable socket: servers accept a new client; any other
-- socket is taken out of the reading set and its waiting coroutine resumed.
function _readable_t:tick(input)
  local handler = _servers[input]
  if not handler then
    _reading:remove(input)
    self.def_tick(input)
    return
  end
  _accept(input, handler)
end

addtaskRead(_readable_t)


-- a task to check ready to write events
local _writable_t = {}

--- Iterates over the sockets select() reported writable (array part of
-- self._evs, read on every step).
function _writable_t:events()
  local i = 0
  return function()
    i = i + 1
    return self._evs[i]
  end
end

--- Takes the socket out of the writing set and resumes its waiting coroutine.
function _writable_t:tick(output)
  _writing:remove(output)
  self.def_tick(output)
end

addtaskWrite(_writable_t)
--
--sleeping threads task
local _sleeping_t = {}

--- Resumes at most one sleeping coroutine that is due at `time`.
function _sleeping_t:tick(time, ...)
  local co = _sleeping:pop(time)
  _doTick(co, ...)
end

--- Yields the current coroutine and wakes it after `sleeptime` seconds
-- (default 0). If sleeptime < 0 it sleeps until explicitly woken up using
-- copas.wakeup.
function copas.sleep(sleeptime)
  local delay = sleeptime or 0
  coroutine.yield(delay, _sleeping)
end

--- Wakes up a sleeping coroutine `co` at the next scheduler step, whether
-- it sleeps with a wake-up time or indefinitely. Unknown coroutines are ignored.
copas.wakeup = function(co)
  _sleeping:wakeup(co)
end

--- Moves sockets whose watchdog timer in `log` has expired into the event
-- list `evs`, even though select() did not report them ready, so their
-- coroutine can retry and notice the timeout.
-- `evs` is keyed both by position and by socket, as socket.select returns.
-- @param log table socket -> timer (_reading_log or _writing_log)
-- @param evs event table to extend
local function _expire_waiting(log, evs)
  for skt, timer in pairs(log) do
    if not evs[skt] and timer:expired() then
      log[skt] = nil
      evs[#evs + 1] = skt
      evs[skt] = #evs
    end
  end
end

-------------------------------------------------------------------------------
-- Checks for reads and writes on sockets
-- Stores the ready sockets on the read/write tasks and returns the select
-- error, except that "timeout" is suppressed when expired watchdogs added
-- sockets to handle anyway.
-------------------------------------------------------------------------------
local function _select(timeout)
  local r_evs, w_evs, err = socket.select(_reading, _writing, timeout)
  _readable_t._evs, _writable_t._evs = r_evs, w_evs

  _expire_waiting(_reading_log, r_evs)
  _expire_waiting(_writing_log, w_evs)

  if err == "timeout" and #r_evs + #w_evs > 0 then
    return nil
  end
  return err
end


-------------------------------------------------------------------------------
-- Dispatcher loop step.
-- Wakes at most one due sleeper, waits (up to `timeout`, or until the next
-- sleeper is due) for socket activity, and handles the ready sockets.
-- Returns false if no data was handled (timeout, or nothing left to do), or
-- true if there was data handled (or nil + error message on select errors)
-------------------------------------------------------------------------------
function copas.step(timeout)
  _sleeping_t:tick(gettime())

  -- Need to wake up the select call in time for the next sleeping event
  local nextwait = _sleeping:getnext()
  if not nextwait then
    if copas.finished() then
      return false
    end
  elseif timeout then
    timeout = math.min(nextwait, timeout)
  else
    timeout = nextwait
  end

  local err = _select(timeout)
  if err == "timeout" then
    return false
  elseif err then
    return nil, err
  end

  for tsk in tasks() do
    for ev in tsk:events() do
      tsk:tick(ev)
    end
  end
  return true
end

-------------------------------------------------------------------------------
-- Check whether there is nothing left to do.
-- Returns true if there are no sockets waiting to read/write and no
-- coroutines sleeping with a wake-up time (continuing would leave Copas in
-- an empty spin); returns false otherwise.
-------------------------------------------------------------------------------
function copas.finished()
  -- coroutines sleeping without a wake-up time (copas.sleep(-1)) are not
  -- counted, as getnext only looks at timed sleepers
  local busy = next(_reading) or next(_writing) or _sleeping:getnext()
  return not busy
end

-------------------------------------------------------------------------------
-- Dispatcher loop.
-- Listens to client requests and handles them, calling copas.step(timeout)
-- until copas.finished() reports nothing left to do. copas.running is true
-- while the loop runs.
-------------------------------------------------------------------------------
function copas.loop(timeout)
  copas.running = true
  while true do
    if copas.finished() then break end
    copas.step(timeout)
  end
  copas.running = false
end

return copas