commit c5b181e08c0e8d622d16c678a82f4a15f956abd7 Author: rubin Date: Tue Mar 5 16:09:18 2024 +0300 initial commit diff --git a/copas.lua b/copas.lua new file mode 100644 index 0000000..8614839 --- /dev/null +++ b/copas.lua @@ -0,0 +1,820 @@ +------------------------------------------------------------------------------- +-- Copas - Coroutine Oriented Portable Asynchronous Services +-- +-- A dispatcher based on coroutines that can be used by TCP/IP servers. +-- Uses LuaSocket as the interface with the TCP/IP stack. +-- +-- Authors: Andre Carregal, Javier Guerra, and Fabio Mascarenhas +-- Contributors: Diego Nehab, Mike Pall, David Burgess, Leonardo Godinho, +-- Thomas Harning Jr., and Gary NG +-- +-- Copyright 2005-2016 - Kepler Project (www.keplerproject.org) +-- +-- $Id: copas.lua,v 1.37 2009/04/07 22:09:52 carregal Exp $ +------------------------------------------------------------------------------- + +if package.loaded["socket.http"] and (_VERSION=="Lua 5.1") then -- obsolete: only for Lua 5.1 compatibility + error("you must require copas before require'ing socket.http") +end + +local socket = require "socket" +local gettime = socket.gettime +local ssl -- only loaded upon demand + +local WATCH_DOG_TIMEOUT = 120 +local UDP_DATAGRAM_MAX = 8192 -- TODO: dynamically get this value from LuaSocket + +local pcall = pcall +if _VERSION=="Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility + pcall = require("coxpcall").pcall +end + +-- Redefines LuaSocket functions with coroutine safe versions +-- (this allows the use of socket.http from within copas) +local function statusHandler(status, ...) + if status then return ... end + local err = (...) + if type(err) == "table" then + return nil, err[1] + else + error(err) + end +end + +function socket.protect(func) + return function (...) + return statusHandler(pcall(func, ...)) + end +end + +function socket.newtry(finalizer) + return function (...) + local status = (...) + if not status then + pcall(finalizer, select(2, ...)) + error({ (select(2, ...)) }, 0) + end + return ... + end +end + +local copas = {} + +-- Meta information is public even if beginning with an "_" +copas._COPYRIGHT = "Copyright (C) 2005-2017 Kepler Project" +copas._DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services" +copas._VERSION = "Copas 2.0.2" + +-- Close the socket associated with the current connection after the handler finishes +copas.autoclose = true + +-- indicator for the loop running +copas.running = false + +------------------------------------------------------------------------------- +-- Simple set implementation based on LuaSocket's tinyirc.lua example +-- adds a FIFO queue for each value in the set +------------------------------------------------------------------------------- +local function newset() + local reverse = {} + local set = {} + local q = {} + setmetatable(set, { __index = { + insert = function(set, value) + if not reverse[value] then + set[#set + 1] = value + reverse[value] = #set + end + end, + + remove = function(set, value) + local index = reverse[value] + if index then + reverse[value] = nil + local top = set[#set] + set[#set] = nil + if top ~= value then + reverse[top] = index + set[index] = top + end + end + end, + + push = function (set, key, itm) + local qKey = q[key] + if qKey == nil then + q[key] = {itm} + else + qKey[#qKey + 1] = itm + end + end, + + pop = function (set, key) + local t = q[key] + if t ~= nil then + local ret = table.remove (t, 1) + if t[1] == nil then + q[key] = nil + end + return ret + end + end + }}) + return set +end + +local fnil = function()end +local _sleeping = { + times = {}, -- list with wake-up times + cos = {}, -- list with coroutines, index matches the 'times' list + lethargy = {}, -- list of coroutines sleeping without a wakeup time + + insert = fnil, + remove = fnil, + push = function(self, sleeptime, co) + if not co then return end + if sleeptime<0 then + --sleep until explicit wakeup through copas.wakeup + self.lethargy[co] = true + return + else + sleeptime = gettime() + sleeptime + end + local t, c = self.times, self.cos + local i, cou = 1, #t + --TODO: do a binary search + while i<=cou and t[i]<=sleeptime do i=i+1 end + table.insert(t, i, sleeptime) + table.insert(c, i, co) + end, + getnext = function(self) -- returns delay until next sleep expires, or nil if there is none + local t = self.times + local delay = t[1] and t[1] - gettime() or nil + + return delay and math.max(delay, 0) or nil + end, + -- find the thread that should wake up to the time + pop = function(self, time) + local t, c = self.times, self.cos + if #t==0 or time= self.timeout_time + end + } +end + +local _servers = newset() -- servers being handled +local _reading_log = {} +local _writing_log = {} + +local _reading = newset() -- sockets currently being read +local _writing = newset() -- sockets currently being written +local _isTimeout = { -- set of errors indicating a timeout + ["timeout"] = true, -- default LuaSocket timeout + ["wantread"] = true, -- LuaSec specific timeout + ["wantwrite"] = true, -- LuaSec specific timeout +} + +------------------------------------------------------------------------------- +-- Coroutine based socket I/O functions. +------------------------------------------------------------------------------- + +local function isTCP(socket) + return string.sub(tostring(socket),1,3) ~= "udp" +end + +-- reads a pattern from a client and yields to the reading set on timeouts +-- UDP: a UDP socket expects a second argument to be a number, so it MUST +-- be provided as the 'pattern' below defaults to a string. Will throw a +-- 'bad argument' error if omitted. +function copas.receive(client, pattern, part, timeout) + local s, err + pattern = pattern or "*l" + local current_log = _reading_log + local timer = newtimer(timeout) + repeat + s, err, part = client:receive(pattern, part) + if s or (not _isTimeout[err]) or timer:expired() then + current_log[client] = nil + return s, err, part + end + if err == "wantwrite" then + current_log = _writing_log + current_log[client] = timer + coroutine.yield(client, _writing) + else + current_log = _reading_log + current_log[client] = timer + coroutine.yield(client, _reading) + end + until false +end + +-- receives data from a client over UDP. Not available for TCP. +-- (this is a copy of receive() method, adapted for receivefrom() use) +function copas.receivefrom(client, size, timeout) + local s, err, port + size = size or UDP_DATAGRAM_MAX + local timer = newtimer(timeout) + repeat + s, err, port = client:receivefrom(size) -- upon success err holds ip address + if s or err ~= "timeout" or timer:expired() then + _reading_log[client] = nil + return s, err, port + end + _reading_log[client] = timer + coroutine.yield(client, _reading) + until false +end + +-- same as above but with special treatment when reading chunks, +-- unblocks on any data received. +function copas.receivePartial(client, pattern, part) + local s, err + pattern = pattern or "*l" + local current_log = _reading_log + local timer = newtimer() + repeat + s, err, part = client:receive(pattern, part) + if s or ((type(pattern)=="number") and part~="" and part ~=nil ) or (not _isTimeout[err]) or timer:expired() then + current_log[client] = nil + return s, err, part + end + if err == "wantwrite" then + current_log = _writing_log + current_log[client] = timer + coroutine.yield(client, _writing) + else + current_log = _reading_log + current_log[client] = timer + coroutine.yield(client, _reading) + end + until false +end + +-- sends data to a client. The operation is buffered and +-- yields to the writing set on timeouts +-- Note: from and to parameters will be ignored by/for UDP sockets +function copas.send(client, data, from, to, timeout) + local s, err + from = from or 1 + local lastIndex = from - 1 + local current_log = _writing_log + local timer = newtimer(timeout) + repeat + s, err, lastIndex = client:send(data, lastIndex + 1, to) + -- adds extra coroutine swap + -- garantees that high throughput doesn't take other threads to starvation + if (math.random(100) > 90) then + current_log[client] = timer -- TODO: how to handle this?? + if current_log == _writing_log then + coroutine.yield(client, _writing) + else + coroutine.yield(client, _reading) + end + end + if s or (not _isTimeout[err]) or timer:expired() then + current_log[client] = nil + return s, err,lastIndex + end + if err == "wantread" then + current_log = _reading_log + current_log[client] = timer + coroutine.yield(client, _reading) + else + current_log = _writing_log + current_log[client] = timer + coroutine.yield(client, _writing) + end + until false +end + +-- sends data to a client over UDP. Not available for TCP. +-- (this is a copy of send() method, adapted for sendto() use) +function copas.sendto(client, data, ip, port, timeout) + local s, err + local timer = newtimer(timeout) + repeat + s, err = client:sendto(data, ip, port) + -- adds extra coroutine swap + -- garantees that high throughput doesn't take other threads to starvation + if (math.random(100) > 90) then + _writing_log[client] = timer + coroutine.yield(client, _writing) + end + if s or err ~= "timeout" or timer:expired() then + _writing_log[client] = nil + return s, err + end + _writing_log[client] = timer + coroutine.yield(client, _writing) + until false +end + +-- waits until connection is completed +function copas.connect(skt, host, port, timeout) + skt:settimeout(0) + local ret, err, tried_more_than_once + local timer = newtimer(timeout) + repeat + ret, err = skt:connect(host, port) + if (not ret) and timer:expired() then + return ret, "timeout" + end + -- non-blocking connect on Windows results in error "Operation already + -- in progress" to indicate that it is completing the request async. So essentially + -- it is the same as "timeout" + -- "Invalid argument" explanation: https://github.com/diegonehab/luasocket/pull/190 + if ret or (err ~= "timeout" and err ~= "Operation already in progress" and err ~= "Invalid argument") then + -- Once the async connect completes, Windows returns the error "already connected" + -- to indicate it is done, so that error should be ignored. Except when it is the + -- first call to connect, then it was already connected to something else and the + -- error should be returned + if (not ret) and (err == "already connected" and tried_more_than_once) then + ret = 1 + err = nil + end + _writing_log[skt] = nil + return ret, err + end + tried_more_than_once = tried_more_than_once or true + _writing_log[skt] = timer + coroutine.yield(skt, _writing) + until false +end + +--- +-- Peforms an (async) ssl handshake on a connected TCP client socket. +-- NOTE: replace all previous socket references, with the returned new ssl wrapped socket +-- Throws error and does not return nil+error, as that might silently fail +-- in code like this; +-- copas.addserver(s1, function(skt) +-- skt = copas.wrap(skt, sparams) +-- skt:dohandshake() --> without explicit error checking, this fails silently and +-- skt:send(body) --> continues unencrypted +-- @param skt Regular LuaSocket CLIENT socket object +-- @param sslt Table with ssl parameters +-- @return wrapped ssl socket, or throws an error +function copas.dohandshake(skt, sslt) + ssl = ssl or require("ssl") + local nskt, err = ssl.wrap(skt, sslt) + if not nskt then return error(err) end + local queue + nskt:settimeout(0) + repeat + local success, err = nskt:dohandshake() + if success then + return nskt + elseif err == "wantwrite" then + queue = _writing + elseif err == "wantread" then + queue = _reading + else + error(err) + end + coroutine.yield(nskt, queue) + until false +end + +-- flushes a client write buffer (deprecated) +function copas.flush(client) +end + +-- wraps a TCP socket to use Copas methods (send, receive, flush and settimeout) +local _skt_mt_tcp = { + __tostring = function(self) + return tostring(self.socket).." (copas wrapped)" + end, + __index = { + + send = function (self, data, from, to) + return copas.send(self.socket, data, from, to, self.timeout) + end, + + receive = function (self, pattern, prefix) + if (self.timeout==0) then + return copas.receivePartial(self.socket, pattern, prefix) + end + return copas.receive(self.socket, pattern, prefix, self.timeout) + end, + + flush = function (self) + return copas.flush(self.socket) + end, + + settimeout = function (self, time) + self.timeout=time + return true + end, + + -- TODO: socket.connect is a shortcut, and must be provided with an alternative + -- if ssl parameters are available, it will also include a handshake + connect = function(self, host, port) + local res, err = copas.connect(self.socket, host, port, self.timeout) + if res and self.ssl_params then + res, err = self:dohandshake() + end + return res, err + end, + + close = function(self, ...) return self.socket:close(...) end, + + -- TODO: socket.bind is a shortcut, and must be provided with an alternative + bind = function(self, ...) return self.socket:bind(...) end, + + -- TODO: is this DNS related? hence blocking? + getsockname = function(self, ...) return self.socket:getsockname(...) end, + + getstats = function(self, ...) return self.socket:getstats(...) end, + + setstats = function(self, ...) return self.socket:setstats(...) end, + + listen = function(self, ...) return self.socket:listen(...) end, + + accept = function(self, ...) return self.socket:accept(...) end, + + setoption = function(self, ...) return self.socket:setoption(...) end, + + -- TODO: is this DNS related? hence blocking? + getpeername = function(self, ...) return self.socket:getpeername(...) end, + + shutdown = function(self, ...) return self.socket:shutdown(...) end, + + dohandshake = function(self, sslt) + self.ssl_params = sslt or self.ssl_params + local nskt, err = copas.dohandshake(self.socket, self.ssl_params) + if not nskt then return nskt, err end + self.socket = nskt -- replace internal socket with the newly wrapped ssl one + return self + end, + + }} + +-- wraps a UDP socket, copy of TCP one adapted for UDP. +local _skt_mt_udp = {__index = { }} +for k,v in pairs(_skt_mt_tcp) do _skt_mt_udp[k] = _skt_mt_udp[k] or v end +for k,v in pairs(_skt_mt_tcp.__index) do _skt_mt_udp.__index[k] = v end + +_skt_mt_udp.__index.sendto = function (self, data, ip, port) + -- UDP sending is non-blocking, but we provide starvation prevention, so replace anyway + return copas.sendto(self.socket, data, ip, port, self.timeout) + end + +_skt_mt_udp.__index.receive = function (self, size) + return copas.receive(self.socket, (size or UDP_DATAGRAM_MAX), nil, self.timeout) + end + +_skt_mt_udp.__index.receivefrom = function (self, size) + return copas.receivefrom(self.socket, (size or UDP_DATAGRAM_MAX), self.timeout) + end + + -- TODO: is this DNS related? hence blocking? +_skt_mt_udp.__index.setpeername = function(self, ...) return self.socket:getpeername(...) end + +_skt_mt_udp.__index.setsockname = function(self, ...) return self.socket:setsockname(...) end + + -- do not close client, as it is also the server for udp. +_skt_mt_udp.__index.close = function(self, ...) return true end + + +--- +-- Wraps a LuaSocket socket object in an async Copas based socket object. +-- @param skt The socket to wrap +-- @sslt (optional) Table with ssl parameters, use an empty table to use ssl with defaults +-- @return wrapped socket object +function copas.wrap (skt, sslt) + if (getmetatable(skt) == _skt_mt_tcp) or (getmetatable(skt) == _skt_mt_udp) then + return skt -- already wrapped + end + skt:settimeout(0) + if not isTCP(skt) then + return setmetatable ({socket = skt}, _skt_mt_udp) + else + return setmetatable ({socket = skt, ssl_params = sslt}, _skt_mt_tcp) + end +end + +--- Wraps a handler in a function that deals with wrapping the socket and doing the +-- optional ssl handshake. +function copas.handler(handler, sslparams) + return function (skt, ...) + skt = copas.wrap(skt) + if sslparams then skt:dohandshake(sslparams) end + return handler(skt, ...) + end +end + + +-------------------------------------------------- +-- Error handling +-------------------------------------------------- + +local _errhandlers = {} -- error handler per coroutine + +function copas.setErrorHandler (err) + local co = coroutine.running() + if co then + _errhandlers [co] = err + end +end + +local function _deferror (msg, co, skt) + print (msg, co, skt) +end + +------------------------------------------------------------------------------- +-- Thread handling +------------------------------------------------------------------------------- + +local function _doTick (co, skt, ...) + if not co then return end + + local ok, res, new_q = coroutine.resume(co, skt, ...) + if ok and res and new_q then + new_q:insert (res) + new_q:push (res, co) + else + if not ok then pcall (_errhandlers [co] or _deferror, res, co, skt) end + if skt and copas.autoclose and isTCP(skt) then + skt:close() -- do not auto-close UDP sockets, as the handler socket is also the server socket + end + _errhandlers [co] = nil + end +end + +-- accepts a connection on socket input +local function _accept(input, handler) + local client = input:accept() + if client then + client:settimeout(0) + local co = coroutine.create(handler) + _doTick (co, client) + --_reading:insert(client) + end + return client +end + +-- handle threads on a queue +local function _tickRead (skt) + _doTick (_reading:pop (skt), skt) +end + +local function _tickWrite (skt) + _doTick (_writing:pop (skt), skt) +end + +------------------------------------------------------------------------------- +-- Adds a server/handler pair to Copas dispatcher +------------------------------------------------------------------------------- +local function addTCPserver(server, handler, timeout) + server:settimeout(timeout or 0) + _servers[server] = handler + _reading:insert(server) +end + +local function addUDPserver(server, handler, timeout) + server:settimeout(timeout or 0) + local co = coroutine.create(handler) + _reading:insert(server) + _doTick (co, server) +end + +function copas.addserver(server, handler, timeout) + if isTCP(server) then + addTCPserver(server, handler, timeout) + else + addUDPserver(server, handler, timeout) + end +end + +function copas.removeserver(server, keep_open) + local s, mt = server, getmetatable(server) + if mt == _skt_mt_tcp or mt == _skt_mt_udp then + s = server.socket + end + _servers[s] = nil + _reading:remove(s) + if keep_open then + return true + end + return server:close() +end + +------------------------------------------------------------------------------- +-- Adds an new coroutine thread to Copas dispatcher +------------------------------------------------------------------------------- +function copas.addthread(handler, ...) + -- create a coroutine that skips the first argument, which is always the socket + -- passed by the scheduler, but `nil` in case of a task/thread + local thread = coroutine.create(function(_, ...) return handler(...) end) + _doTick (thread, nil, ...) + return thread +end + +------------------------------------------------------------------------------- +-- tasks registering +------------------------------------------------------------------------------- + +local _tasks = {} + +local function addtaskRead (tsk) + -- lets tasks call the default _tick() + tsk.def_tick = _tickRead + + _tasks [tsk] = true +end + +local function addtaskWrite (tsk) + -- lets tasks call the default _tick() + tsk.def_tick = _tickWrite + + _tasks [tsk] = true +end + +local function tasks () + return next, _tasks +end + +------------------------------------------------------------------------------- +-- main tasks: manage readable and writable socket sets +------------------------------------------------------------------------------- +-- a task to check ready to read events +local _readable_t = { + events = function(self) + local i = 0 + return function () + i = i + 1 + return self._evs [i] + end + end, + + tick = function (self, input) + local handler = _servers[input] + if handler then + input = _accept(input, handler) + else + _reading:remove (input) + self.def_tick (input) + end + end +} + +addtaskRead (_readable_t) + + +-- a task to check ready to write events +local _writable_t = { + events = function (self) + local i = 0 + return function () + i = i + 1 + return self._evs [i] + end + end, + + tick = function (self, output) + _writing:remove (output) + self.def_tick (output) + end +} + +addtaskWrite (_writable_t) +-- +--sleeping threads task +local _sleeping_t = { + tick = function (self, time, ...) + _doTick(_sleeping:pop(time), ...) + end +} + +-- yields the current coroutine and wakes it after 'sleeptime' seconds. +-- If sleeptime<0 then it sleeps until explicitly woken up using 'wakeup' +function copas.sleep(sleeptime) + coroutine.yield((sleeptime or 0), _sleeping) +end + +-- Wakes up a sleeping coroutine 'co'. +function copas.wakeup(co) + _sleeping:wakeup(co) +end + +local last_cleansing = 0 + +------------------------------------------------------------------------------- +-- Checks for reads and writes on sockets +------------------------------------------------------------------------------- +local function _select(timeout) + local r_evs, w_evs, err = socket.select(_reading, _writing, timeout) + _readable_t._evs, _writable_t._evs = r_evs, w_evs + + -- Check all sockets selected for reading, and check how long they have been waiting + -- for data already, without select returning them as readable + for skt, timer in pairs(_reading_log) do + if not r_evs[skt] and timer:expired() then + -- This one timedout while waiting to become readable, so move + -- it in the readable list and try and read anyway, despite not + -- having been returned by select + _reading_log[skt] = nil + r_evs[#r_evs + 1] = skt + r_evs[skt] = #r_evs + end + end + + -- Do the same for writing + for skt, timer in pairs(_writing_log) do + if not w_evs[skt] and timer:expired() then + _writing_log[skt] = nil + w_evs[#w_evs + 1] = skt + w_evs[skt] = #w_evs + end + end + + if err == "timeout" and #r_evs + #w_evs > 0 then + return nil + else + return err + end +end + + +------------------------------------------------------------------------------- +-- Dispatcher loop step. +-- Listen to client requests and handles them +-- Returns false if no data was handled (timeout), or true if there was data +-- handled (or nil + error message) +------------------------------------------------------------------------------- +function copas.step(timeout) + _sleeping_t:tick(gettime()) + + -- Need to wake up the select call in time for the next sleeping event + local nextwait = _sleeping:getnext() + if nextwait then + timeout = timeout and math.min(nextwait, timeout) or nextwait + else + if copas.finished() then + return false + end + end + + local err = _select (timeout) + if err then + if err == "timeout" then return false end + return nil, err + end + + for tsk in tasks() do + for ev in tsk:events() do + tsk:tick (ev) + end + end + return true +end + +------------------------------------------------------------------------------- +-- Check whether there is something to do. +-- returns false if there are no sockets for read/write nor tasks scheduled +-- (which means Copas is in an empty spin) +------------------------------------------------------------------------------- +function copas.finished() + return not (next(_reading) or next(_writing) or _sleeping:getnext()) +end + +------------------------------------------------------------------------------- +-- Dispatcher endless loop. +-- Listen to client requests and handles them forever +------------------------------------------------------------------------------- +function copas.loop(timeout) + copas.running = true + while not copas.finished() do copas.step(timeout) end + copas.running = false +end + +return copas diff --git a/copas/.travis.yml b/copas/.travis.yml new file mode 100644 index 0000000..19db950 --- /dev/null +++ b/copas/.travis.yml @@ -0,0 +1,30 @@ +language: python +sudo: false + +env: + - LUA="lua=5.1" + - LUA="lua=5.2" + - LUA="lua=5.3" + - LUA="luajit=2.0" + - LUA="luajit=2.1" + +before_install: + - pip install hererocks + - hererocks lua_install -r^ --$LUA + - export PATH=$PATH:$PWD/lua_install/bin # Add directory with all installed binaries to PATH + +install: + - luarocks install luasec # optional dependency + - luarocks install luacov + - luarocks install luacov-coveralls + +script: + - luarocks make $(ls rockspec/copas-[a-z]* | sort -r | head -n 1) # get latest development rockspec + - make coverage + +after_success: + - luacov-coveralls --exclude $TRAVIS_BUILD_DIR/lua_install --exclude tests + +branches: + except: + - gh-pages diff --git a/copas/ftp.lua b/copas/ftp.lua new file mode 100644 index 0000000..109161f --- /dev/null +++ b/copas/ftp.lua @@ -0,0 +1,95 @@ +------------------------------------------------------------------- +-- identical to the socket.ftp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local socket = require("socket") +local ftp = require("socket.ftp") +local ltn12 = require("ltn12") +local url = require("socket.url") + + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + TIMEOUT = true, + PASSWORD = true, + USER = true +} + +copas.ftp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.TIMEOUT, etc. + __index = ftp, + -- Setting constants is forwarded to the luasocket.ftp module. + __newindex = function(self, key, value) + if forwards[key] then ftp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.ftp + +---[[ copy of Luasocket stuff here untile PR #133 is accepted +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local default = { + path = "/", + scheme = "ftp" +} + +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") + local pat = "^type=(.)$" + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") + end + return t +end + +-- parses a simple form into the advanced form +-- if `body` is provided, a PUT, otherwise a GET. +-- If GET, then a field `target` is added to store the results +_M.parseRequest = function(u, body) + local t = parse(u) + if body then + t.source = ltn12.source.string(body) + else + t.target = {} + t.sink = ltn12.sink.table(t.target) + end +end +--]] + +_M.put = socket.protect(function(putt, body) + if type(putt) == "string" then + putt = _M.parseRequest(putt, body) + _M.put(putt) + return table.concat(putt.target) + else + putt.create = putt.create or create + return ftp.put(putt) + end +end) + +_M.get = socket.protect(function(gett) + if type(gett) == "string" then + gett = _M.parseRequest(gett) + _M.get(gett) + return table.concat(gett.target) + else + gett.create = gett.create or create + return ftp.get(gett) + end +end) + +_M.command = function(cmdt) + cmdt.create = cmdt.create or create + return ftp.command(cmdt) +end + +return _M diff --git a/copas/ftp.lua~ b/copas/ftp.lua~ new file mode 100644 index 0000000..109161f --- /dev/null +++ b/copas/ftp.lua~ @@ -0,0 +1,95 @@ +------------------------------------------------------------------- +-- identical to the socket.ftp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local socket = require("socket") +local ftp = require("socket.ftp") +local ltn12 = require("ltn12") +local url = require("socket.url") + + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + TIMEOUT = true, + PASSWORD = true, + USER = true +} + +copas.ftp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.TIMEOUT, etc. + __index = ftp, + -- Setting constants is forwarded to the luasocket.ftp module. + __newindex = function(self, key, value) + if forwards[key] then ftp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.ftp + +---[[ copy of Luasocket stuff here untile PR #133 is accepted +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local default = { + path = "/", + scheme = "ftp" +} + +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") + local pat = "^type=(.)$" + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") + end + return t +end + +-- parses a simple form into the advanced form +-- if `body` is provided, a PUT, otherwise a GET. +-- If GET, then a field `target` is added to store the results +_M.parseRequest = function(u, body) + local t = parse(u) + if body then + t.source = ltn12.source.string(body) + else + t.target = {} + t.sink = ltn12.sink.table(t.target) + end +end +--]] + +_M.put = socket.protect(function(putt, body) + if type(putt) == "string" then + putt = _M.parseRequest(putt, body) + _M.put(putt) + return table.concat(putt.target) + else + putt.create = putt.create or create + return ftp.put(putt) + end +end) + +_M.get = socket.protect(function(gett) + if type(gett) == "string" then + gett = _M.parseRequest(gett) + _M.get(gett) + return table.concat(gett.target) + else + gett.create = gett.create or create + return ftp.get(gett) + end +end) + +_M.command = function(cmdt) + cmdt.create = cmdt.create or create + return ftp.command(cmdt) +end + +return _M diff --git a/copas/ftp.lua~~ b/copas/ftp.lua~~ new file mode 100644 index 0000000..109161f --- /dev/null +++ b/copas/ftp.lua~~ @@ -0,0 +1,95 @@ +------------------------------------------------------------------- +-- identical to the socket.ftp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local socket = require("socket") +local ftp = require("socket.ftp") +local ltn12 = require("ltn12") +local url = require("socket.url") + + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + TIMEOUT = true, + PASSWORD = true, + USER = true +} + +copas.ftp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.TIMEOUT, etc. + __index = ftp, + -- Setting constants is forwarded to the luasocket.ftp module. + __newindex = function(self, key, value) + if forwards[key] then ftp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.ftp + +---[[ copy of Luasocket stuff here untile PR #133 is accepted +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local default = { + path = "/", + scheme = "ftp" +} + +-- a copy of the version in LuaSockets' ftp.lua +-- no 'create' can be passed in the string form, hence a local copy here +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") + local pat = "^type=(.)$" + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") + end + return t +end + +-- parses a simple form into the advanced form +-- if `body` is provided, a PUT, otherwise a GET. +-- If GET, then a field `target` is added to store the results +_M.parseRequest = function(u, body) + local t = parse(u) + if body then + t.source = ltn12.source.string(body) + else + t.target = {} + t.sink = ltn12.sink.table(t.target) + end +end +--]] + +_M.put = socket.protect(function(putt, body) + if type(putt) == "string" then + putt = _M.parseRequest(putt, body) + _M.put(putt) + return table.concat(putt.target) + else + putt.create = putt.create or create + return ftp.put(putt) + end +end) + +_M.get = socket.protect(function(gett) + if type(gett) == "string" then + gett = _M.parseRequest(gett) + _M.get(gett) + return table.concat(gett.target) + else + gett.create = gett.create or create + return ftp.get(gett) + end +end) + +_M.command = function(cmdt) + cmdt.create = cmdt.create or create + return ftp.command(cmdt) +end + +return _M diff --git a/copas/http.lua b/copas/http.lua new file mode 100644 index 0000000..89246af --- /dev/null +++ b/copas/http.lua @@ -0,0 +1,413 @@ +----------------------------------------------------------------------------- +-- Full copy of the LuaSocket code, modified to include +-- https and http/https redirects, and Copas async enabled. +----------------------------------------------------------------------------- +-- HTTP/1.1 client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +------------------------------------------------------------------------------- +local socket = require("socket") +local url = require("socket.url") +local ltn12 = require("ltn12") +local mime = require("mime") +local string = require("string") +local headers = require("socket.headers") +local base = _G +local table = require("table") +local try = socket.try +local copas = require("copas") +copas.http = {} +local _M = copas.http + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- connection timeout in seconds +_M.TIMEOUT = 60 +-- default port for document retrieval +_M.PORT = 80 +-- user agent field sent in request +_M.USERAGENT = socket._VERSION + +-- Default settings for SSL +_M.SSLPORT = 443 +_M.SSLPROTOCOL = "tlsv1_2" +_M.SSLOPTIONS = "all" +_M.SSLVERIFY = "none" + + +----------------------------------------------------------------------------- +-- Reads MIME headers from a connection, unfolding where needed +----------------------------------------------------------------------------- +local function receiveheaders(sock, headers) + local line, name, value, err + headers = headers or {} + -- get first line + line, err = sock:receive() + if err then return nil, err end + -- headers go until a blank line is found + while line ~= "" do + -- get field-name and value + name, value = socket.skip(2, string.find(line, "^(.-):%s*(.*)")) + if not (name and value) then return nil, "malformed reponse headers" end + name = string.lower(name) + -- get next line (value might be folded) + line, err = sock:receive() + if err then return nil, err end + -- unfold any folded values + while string.find(line, "^%s") do + value = value .. line + line = sock:receive() + if err then return nil, err end + end + -- save pair in table + if headers[name] then headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + return headers +end + +----------------------------------------------------------------------------- +-- Extra sources and sinks +----------------------------------------------------------------------------- +socket.sourcet["http-chunked"] = function(sock, headers) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + -- get chunk size, skip extention + local line, err = sock:receive() + if err then return nil, err end + local size = base.tonumber(string.gsub(line, ";.*", ""), 16) + if not size then return nil, "invalid chunk size" end + -- was it the last chunk? + if size > 0 then + -- if not, get chunk and skip terminating CRLF + local chunk, err = sock:receive(size) + if chunk then sock:receive() end + return chunk, err + else + -- if it was, read trailers into headers table + headers, err = receiveheaders(sock, headers) + if not headers then return nil, err end + end + end + }) +end + +socket.sinkt["http-chunked"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then return sock:send("0\r\n\r\n") end + local size = string.format("%X\r\n", string.len(chunk)) + return sock:send(size .. chunk .. "\r\n") + end + }) +end + +----------------------------------------------------------------------------- +-- Low level HTTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(reqt) + -- create socket with user connect function + local c = socket.try(reqt:create()) -- method call, passing reqt table as self! + local h = base.setmetatable({ c = c }, metat) + -- create finalized try + h.try = socket.newtry(function() h:close() end) + -- set timeout before connecting + h.try(c:settimeout(_M.TIMEOUT)) + h.try(c:connect(reqt.host, reqt.port or _M.PORT)) + -- here everything worked + return h +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return self.try(self.c:send(reqline)) +end + +function metat.__index:sendheaders(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f, v in base.pairs(tosend) do + h = (canonic[f] or f) .. ": " .. v .. "\r\n" .. h + end + self.try(self.c:send(h)) + return 1 +end + +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode = "http-chunked" + if headers["content-length"] then mode = "keep-open" end + return self.try(ltn12.pump.all(source, socket.sink(mode, self.c), step)) +end + +function metat.__index:receivestatusline() + local status = self.try(self.c:receive(5)) + -- identify HTTP/0.9 responses, which do not contain a status line + -- this is just a heuristic, but is what the RFC recommends + if status ~= "HTTP/" then return nil, status end + -- otherwise proceed reading a status line + status = self.try(self.c:receive("*l", status)) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return self.try(base.tonumber(code), status) +end + +function metat.__index:receiveheaders() + return self.try(receiveheaders(self.c)) +end + +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = base.tonumber(headers["content-length"]) + local t = headers["transfer-encoding"] -- shortcut + local mode = "default" -- connection close + if t and t ~= "identity" then mode = "http-chunked" + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end + return self.try(ltn12.pump.all(socket.source(mode, self.c, length), + sink, step)) +end + +function metat.__index:receive09body(status, sink, step) + local source = ltn12.source.rewind(socket.source("until-closed", self.c)) + source(status) + return self.try(ltn12.pump.all(source, sink, step)) +end + +function metat.__index:close() + return self.c:close() +end + +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function adjusturi(reqt) + local u = reqt + -- if there is a proxy, we need the full url. otherwise, just a part. + if not reqt.proxy and not _M.PROXY then + u = { + path = socket.try(reqt.path, "invalid path 'nil'"), + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment + } + end + return url.build(u) +end + +local function adjustproxy(reqt) + local proxy = reqt.proxy or _M.PROXY + if proxy then + proxy = url.parse(proxy) + return proxy.host, proxy.port or 3128 + else + return reqt.host, reqt.port + end +end + +local function adjustheaders(reqt) + -- default headers + local host = string.gsub(reqt.authority, "^.-@", "") + local lower = { + ["user-agent"] = _M.USERAGENT, + ["host"] = host, + ["connection"] = "close, TE", + ["te"] = "trailers" + } + -- if we have authentication information, pass it along + if reqt.user and reqt.password then + lower["authorization"] = + "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + end + -- override with user headers + for i,v in base.pairs(reqt.headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +-- default url parts +local default = { + host = "", + port = _M.PORT, + path ="/", + scheme = "http" +} + +local function adjustrequest(reqt) + -- parse url if provided + local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- explicit components override url + for i,v in base.pairs(reqt) do nreqt[i] = v end + if nreqt.port == "" then nreqt.port = 80 end + socket.try(nreqt.host and nreqt.host ~= "", + "invalid host '" .. base.tostring(nreqt.host) .. "'") + -- compute uri if user hasn't overriden + nreqt.uri = reqt.uri or adjusturi(nreqt) + -- ajust host and port if there is a proxy + nreqt.host, nreqt.port = adjustproxy(nreqt) + -- adjust headers in request + nreqt.headers = adjustheaders(nreqt) + return nreqt +end + +local function shouldredirect(reqt, code, headers) + return headers.location and + string.gsub(headers.location, "%s", "") ~= "" and + (reqt.redirect ~= false) and + (code == 301 or code == 302 or code == 303 or code == 307) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end + +local function shouldreceivebody(reqt, code) + if reqt.method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 +end + +-- forward declarations +local trequest, tredirect + +--[[local]] function tredirect(reqt, location) + local result, code, headers, status = trequest { + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = url.absolute(reqt.url, location), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1, + create = reqt.create + } + -- pass location header back as a hint we redirected + headers = headers or {} + headers.location = headers.location or location + return result, code, headers, status +end + +--[[local]] function trequest(reqt) + -- we loop until we get what we want, or + -- until we are sure there is no way to get it + local nreqt = adjustrequest(reqt) + local h = _M.open(nreqt) + -- send request line and headers + h:sendrequestline(nreqt.method, nreqt.uri) + h:sendheaders(nreqt.headers) + -- if there is a body, send it + if nreqt.source then + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + end + local code, status = h:receivestatusline() + -- if it is an HTTP/0.9 server, simply get the body and we are done + if not code then + h:receive09body(status, nreqt.sink, nreqt.step) + return 1, 200 + end + local headers + -- ignore any 100-continue messages + while code == 100 do + headers = h:receiveheaders() + code, status = h:receivestatusline() + end + headers = h:receiveheaders() + -- at this point we should have a honest reply from the server + -- we can't redirect if we already used the source, so we report the error + if shouldredirect(nreqt, code, headers) and not nreqt.source then + h:close() + return tredirect(reqt, headers.location) + end + -- here we are finally done + if shouldreceivebody(nreqt, code) then + h:receivebody(headers, nreqt.sink, nreqt.step) + end + h:close() + return 1, code, headers, status +end + +-- Return a function which performs the SSL/TLS connection. +local function tcp(params) + params = params or {} + -- Default settings + params.protocol = params.protocol or _M.SSLPROTOCOL + params.options = params.options or _M.SSLOPTIONS + params.verify = params.verify or _M.SSLVERIFY + params.mode = "client" -- Force client mode + -- upvalue to track https -> http redirection + local washttps = false + -- 'create' function for LuaSocket + return function (reqt) + local u = url.parse(reqt.url) + if (reqt.scheme or u.scheme) == "https" then + -- https, provide an ssl wrapped socket + local conn = copas.wrap(socket.tcp(), params) + -- insert https default port, overriding http port inserted by LuaSocket + if not u.port then + u.port = _M.SSLPORT + reqt.url = url.build(u) + reqt.port = _M.SSLPORT + end + washttps = true + return conn + else + -- regular http, needs just a socket... + if washttps and params.redirect ~= "all" then + try(nil, "Unallowed insecure redirect https to http") + end + return copas.wrap(socket.tcp()) + end + end +end + +-- parses a shorthand form into the advanced table form. +-- adds field `target` to the table. This will hold the return values. +_M.parseRequest = function(u, b) + local reqt = { + url = u, + target = {}, + } + reqt.sink = ltn12.sink.table(reqt.target) + if b then + reqt.source = ltn12.source.string(b) + reqt.headers = { + ["content-length"] = string.len(b), + ["content-type"] = "application/x-www-form-urlencoded" + } + reqt.method = "POST" + end + return reqt +end + +_M.request = socket.protect(function(reqt, body) + if base.type(reqt) == "string" then + reqt = _M.parseRequest(reqt, body) + local ok, code, headers, status = _M.request(reqt) + + if ok then + return table.concat(reqt.target), code, headers, status + else + return nil, code + end + else + reqt.create = reqt.create or tcp(reqt) + return trequest(reqt) + end +end) + +return _M diff --git a/copas/http.lua~ b/copas/http.lua~ new file mode 100644 index 0000000..89246af --- /dev/null +++ b/copas/http.lua~ @@ -0,0 +1,413 @@ +----------------------------------------------------------------------------- +-- Full copy of the LuaSocket code, modified to include +-- https and http/https redirects, and Copas async enabled. +----------------------------------------------------------------------------- +-- HTTP/1.1 client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +------------------------------------------------------------------------------- +local socket = require("socket") +local url = require("socket.url") +local ltn12 = require("ltn12") +local mime = require("mime") +local string = require("string") +local headers = require("socket.headers") +local base = _G +local table = require("table") +local try = socket.try +local copas = require("copas") +copas.http = {} +local _M = copas.http + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- connection timeout in seconds +_M.TIMEOUT = 60 +-- default port for document retrieval +_M.PORT = 80 +-- user agent field sent in request +_M.USERAGENT = socket._VERSION + +-- Default settings for SSL +_M.SSLPORT = 443 +_M.SSLPROTOCOL = "tlsv1_2" +_M.SSLOPTIONS = "all" +_M.SSLVERIFY = "none" + + +----------------------------------------------------------------------------- +-- Reads MIME headers from a connection, unfolding where needed +----------------------------------------------------------------------------- +local function receiveheaders(sock, headers) + local line, name, value, err + headers = headers or {} + -- get first line + line, err = sock:receive() + if err then return nil, err end + -- headers go until a blank line is found + while line ~= "" do + -- get field-name and value + name, value = socket.skip(2, string.find(line, "^(.-):%s*(.*)")) + if not (name and value) then return nil, "malformed reponse headers" end + name = string.lower(name) + -- get next line (value might be folded) + line, err = sock:receive() + if err then return nil, err end + -- unfold any folded values + while string.find(line, "^%s") do + value = value .. line + line = sock:receive() + if err then return nil, err end + end + -- save pair in table + if headers[name] then headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + return headers +end + +----------------------------------------------------------------------------- +-- Extra sources and sinks +----------------------------------------------------------------------------- +socket.sourcet["http-chunked"] = function(sock, headers) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + -- get chunk size, skip extention + local line, err = sock:receive() + if err then return nil, err end + local size = base.tonumber(string.gsub(line, ";.*", ""), 16) + if not size then return nil, "invalid chunk size" end + -- was it the last chunk? + if size > 0 then + -- if not, get chunk and skip terminating CRLF + local chunk, err = sock:receive(size) + if chunk then sock:receive() end + return chunk, err + else + -- if it was, read trailers into headers table + headers, err = receiveheaders(sock, headers) + if not headers then return nil, err end + end + end + }) +end + +socket.sinkt["http-chunked"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then return sock:send("0\r\n\r\n") end + local size = string.format("%X\r\n", string.len(chunk)) + return sock:send(size .. chunk .. "\r\n") + end + }) +end + +----------------------------------------------------------------------------- +-- Low level HTTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(reqt) + -- create socket with user connect function + local c = socket.try(reqt:create()) -- method call, passing reqt table as self! + local h = base.setmetatable({ c = c }, metat) + -- create finalized try + h.try = socket.newtry(function() h:close() end) + -- set timeout before connecting + h.try(c:settimeout(_M.TIMEOUT)) + h.try(c:connect(reqt.host, reqt.port or _M.PORT)) + -- here everything worked + return h +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return self.try(self.c:send(reqline)) +end + +function metat.__index:sendheaders(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f, v in base.pairs(tosend) do + h = (canonic[f] or f) .. ": " .. v .. "\r\n" .. h + end + self.try(self.c:send(h)) + return 1 +end + +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode = "http-chunked" + if headers["content-length"] then mode = "keep-open" end + return self.try(ltn12.pump.all(source, socket.sink(mode, self.c), step)) +end + +function metat.__index:receivestatusline() + local status = self.try(self.c:receive(5)) + -- identify HTTP/0.9 responses, which do not contain a status line + -- this is just a heuristic, but is what the RFC recommends + if status ~= "HTTP/" then return nil, status end + -- otherwise proceed reading a status line + status = self.try(self.c:receive("*l", status)) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return self.try(base.tonumber(code), status) +end + +function metat.__index:receiveheaders() + return self.try(receiveheaders(self.c)) +end + +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = base.tonumber(headers["content-length"]) + local t = headers["transfer-encoding"] -- shortcut + local mode = "default" -- connection close + if t and t ~= "identity" then mode = "http-chunked" + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end + return self.try(ltn12.pump.all(socket.source(mode, self.c, length), + sink, step)) +end + +function metat.__index:receive09body(status, sink, step) + local source = ltn12.source.rewind(socket.source("until-closed", self.c)) + source(status) + return self.try(ltn12.pump.all(source, sink, step)) +end + +function metat.__index:close() + return self.c:close() +end + +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function adjusturi(reqt) + local u = reqt + -- if there is a proxy, we need the full url. otherwise, just a part. + if not reqt.proxy and not _M.PROXY then + u = { + path = socket.try(reqt.path, "invalid path 'nil'"), + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment + } + end + return url.build(u) +end + +local function adjustproxy(reqt) + local proxy = reqt.proxy or _M.PROXY + if proxy then + proxy = url.parse(proxy) + return proxy.host, proxy.port or 3128 + else + return reqt.host, reqt.port + end +end + +local function adjustheaders(reqt) + -- default headers + local host = string.gsub(reqt.authority, "^.-@", "") + local lower = { + ["user-agent"] = _M.USERAGENT, + ["host"] = host, + ["connection"] = "close, TE", + ["te"] = "trailers" + } + -- if we have authentication information, pass it along + if reqt.user and reqt.password then + lower["authorization"] = + "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + end + -- override with user headers + for i,v in base.pairs(reqt.headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +-- default url parts +local default = { + host = "", + port = _M.PORT, + path ="/", + scheme = "http" +} + +local function adjustrequest(reqt) + -- parse url if provided + local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- explicit components override url + for i,v in base.pairs(reqt) do nreqt[i] = v end + if nreqt.port == "" then nreqt.port = 80 end + socket.try(nreqt.host and nreqt.host ~= "", + "invalid host '" .. base.tostring(nreqt.host) .. "'") + -- compute uri if user hasn't overriden + nreqt.uri = reqt.uri or adjusturi(nreqt) + -- ajust host and port if there is a proxy + nreqt.host, nreqt.port = adjustproxy(nreqt) + -- adjust headers in request + nreqt.headers = adjustheaders(nreqt) + return nreqt +end + +local function shouldredirect(reqt, code, headers) + return headers.location and + string.gsub(headers.location, "%s", "") ~= "" and + (reqt.redirect ~= false) and + (code == 301 or code == 302 or code == 303 or code == 307) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end + +local function shouldreceivebody(reqt, code) + if reqt.method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 +end + +-- forward declarations +local trequest, tredirect + +--[[local]] function tredirect(reqt, location) + local result, code, headers, status = trequest { + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = url.absolute(reqt.url, location), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1, + create = reqt.create + } + -- pass location header back as a hint we redirected + headers = headers or {} + headers.location = headers.location or location + return result, code, headers, status +end + +--[[local]] function trequest(reqt) + -- we loop until we get what we want, or + -- until we are sure there is no way to get it + local nreqt = adjustrequest(reqt) + local h = _M.open(nreqt) + -- send request line and headers + h:sendrequestline(nreqt.method, nreqt.uri) + h:sendheaders(nreqt.headers) + -- if there is a body, send it + if nreqt.source then + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + end + local code, status = h:receivestatusline() + -- if it is an HTTP/0.9 server, simply get the body and we are done + if not code then + h:receive09body(status, nreqt.sink, nreqt.step) + return 1, 200 + end + local headers + -- ignore any 100-continue messages + while code == 100 do + headers = h:receiveheaders() + code, status = h:receivestatusline() + end + headers = h:receiveheaders() + -- at this point we should have a honest reply from the server + -- we can't redirect if we already used the source, so we report the error + if shouldredirect(nreqt, code, headers) and not nreqt.source then + h:close() + return tredirect(reqt, headers.location) + end + -- here we are finally done + if shouldreceivebody(nreqt, code) then + h:receivebody(headers, nreqt.sink, nreqt.step) + end + h:close() + return 1, code, headers, status +end + +-- Return a function which performs the SSL/TLS connection. +local function tcp(params) + params = params or {} + -- Default settings + params.protocol = params.protocol or _M.SSLPROTOCOL + params.options = params.options or _M.SSLOPTIONS + params.verify = params.verify or _M.SSLVERIFY + params.mode = "client" -- Force client mode + -- upvalue to track https -> http redirection + local washttps = false + -- 'create' function for LuaSocket + return function (reqt) + local u = url.parse(reqt.url) + if (reqt.scheme or u.scheme) == "https" then + -- https, provide an ssl wrapped socket + local conn = copas.wrap(socket.tcp(), params) + -- insert https default port, overriding http port inserted by LuaSocket + if not u.port then + u.port = _M.SSLPORT + reqt.url = url.build(u) + reqt.port = _M.SSLPORT + end + washttps = true + return conn + else + -- regular http, needs just a socket... + if washttps and params.redirect ~= "all" then + try(nil, "Unallowed insecure redirect https to http") + end + return copas.wrap(socket.tcp()) + end + end +end + +-- parses a shorthand form into the advanced table form. +-- adds field `target` to the table. This will hold the return values. +_M.parseRequest = function(u, b) + local reqt = { + url = u, + target = {}, + } + reqt.sink = ltn12.sink.table(reqt.target) + if b then + reqt.source = ltn12.source.string(b) + reqt.headers = { + ["content-length"] = string.len(b), + ["content-type"] = "application/x-www-form-urlencoded" + } + reqt.method = "POST" + end + return reqt +end + +_M.request = socket.protect(function(reqt, body) + if base.type(reqt) == "string" then + reqt = _M.parseRequest(reqt, body) + local ok, code, headers, status = _M.request(reqt) + + if ok then + return table.concat(reqt.target), code, headers, status + else + return nil, code + end + else + reqt.create = reqt.create or tcp(reqt) + return trequest(reqt) + end +end) + +return _M diff --git a/copas/http.lua~~ b/copas/http.lua~~ new file mode 100644 index 0000000..89246af --- /dev/null +++ b/copas/http.lua~~ @@ -0,0 +1,413 @@ +----------------------------------------------------------------------------- +-- Full copy of the LuaSocket code, modified to include +-- https and http/https redirects, and Copas async enabled. +----------------------------------------------------------------------------- +-- HTTP/1.1 client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +------------------------------------------------------------------------------- +local socket = require("socket") +local url = require("socket.url") +local ltn12 = require("ltn12") +local mime = require("mime") +local string = require("string") +local headers = require("socket.headers") +local base = _G +local table = require("table") +local try = socket.try +local copas = require("copas") +copas.http = {} +local _M = copas.http + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- connection timeout in seconds +_M.TIMEOUT = 60 +-- default port for document retrieval +_M.PORT = 80 +-- user agent field sent in request +_M.USERAGENT = socket._VERSION + +-- Default settings for SSL +_M.SSLPORT = 443 +_M.SSLPROTOCOL = "tlsv1_2" +_M.SSLOPTIONS = "all" +_M.SSLVERIFY = "none" + + +----------------------------------------------------------------------------- +-- Reads MIME headers from a connection, unfolding where needed +----------------------------------------------------------------------------- +local function receiveheaders(sock, headers) + local line, name, value, err + headers = headers or {} + -- get first line + line, err = sock:receive() + if err then return nil, err end + -- headers go until a blank line is found + while line ~= "" do + -- get field-name and value + name, value = socket.skip(2, string.find(line, "^(.-):%s*(.*)")) + if not (name and value) then return nil, "malformed reponse headers" end + name = string.lower(name) + -- get next line (value might be folded) + line, err = sock:receive() + if err then return nil, err end + -- unfold any folded values + while string.find(line, "^%s") do + value = value .. line + line = sock:receive() + if err then return nil, err end + end + -- save pair in table + if headers[name] then headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + return headers +end + +----------------------------------------------------------------------------- +-- Extra sources and sinks +----------------------------------------------------------------------------- +socket.sourcet["http-chunked"] = function(sock, headers) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + -- get chunk size, skip extention + local line, err = sock:receive() + if err then return nil, err end + local size = base.tonumber(string.gsub(line, ";.*", ""), 16) + if not size then return nil, "invalid chunk size" end + -- was it the last chunk? + if size > 0 then + -- if not, get chunk and skip terminating CRLF + local chunk, err = sock:receive(size) + if chunk then sock:receive() end + return chunk, err + else + -- if it was, read trailers into headers table + headers, err = receiveheaders(sock, headers) + if not headers then return nil, err end + end + end + }) +end + +socket.sinkt["http-chunked"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then return sock:send("0\r\n\r\n") end + local size = string.format("%X\r\n", string.len(chunk)) + return sock:send(size .. chunk .. "\r\n") + end + }) +end + +----------------------------------------------------------------------------- +-- Low level HTTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(reqt) + -- create socket with user connect function + local c = socket.try(reqt:create()) -- method call, passing reqt table as self! + local h = base.setmetatable({ c = c }, metat) + -- create finalized try + h.try = socket.newtry(function() h:close() end) + -- set timeout before connecting + h.try(c:settimeout(_M.TIMEOUT)) + h.try(c:connect(reqt.host, reqt.port or _M.PORT)) + -- here everything worked + return h +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return self.try(self.c:send(reqline)) +end + +function metat.__index:sendheaders(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f, v in base.pairs(tosend) do + h = (canonic[f] or f) .. ": " .. v .. "\r\n" .. h + end + self.try(self.c:send(h)) + return 1 +end + +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode = "http-chunked" + if headers["content-length"] then mode = "keep-open" end + return self.try(ltn12.pump.all(source, socket.sink(mode, self.c), step)) +end + +function metat.__index:receivestatusline() + local status = self.try(self.c:receive(5)) + -- identify HTTP/0.9 responses, which do not contain a status line + -- this is just a heuristic, but is what the RFC recommends + if status ~= "HTTP/" then return nil, status end + -- otherwise proceed reading a status line + status = self.try(self.c:receive("*l", status)) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return self.try(base.tonumber(code), status) +end + +function metat.__index:receiveheaders() + return self.try(receiveheaders(self.c)) +end + +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = base.tonumber(headers["content-length"]) + local t = headers["transfer-encoding"] -- shortcut + local mode = "default" -- connection close + if t and t ~= "identity" then mode = "http-chunked" + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end + return self.try(ltn12.pump.all(socket.source(mode, self.c, length), + sink, step)) +end + +function metat.__index:receive09body(status, sink, step) + local source = ltn12.source.rewind(socket.source("until-closed", self.c)) + source(status) + return self.try(ltn12.pump.all(source, sink, step)) +end + +function metat.__index:close() + return self.c:close() +end + +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function adjusturi(reqt) + local u = reqt + -- if there is a proxy, we need the full url. otherwise, just a part. + if not reqt.proxy and not _M.PROXY then + u = { + path = socket.try(reqt.path, "invalid path 'nil'"), + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment + } + end + return url.build(u) +end + +local function adjustproxy(reqt) + local proxy = reqt.proxy or _M.PROXY + if proxy then + proxy = url.parse(proxy) + return proxy.host, proxy.port or 3128 + else + return reqt.host, reqt.port + end +end + +local function adjustheaders(reqt) + -- default headers + local host = string.gsub(reqt.authority, "^.-@", "") + local lower = { + ["user-agent"] = _M.USERAGENT, + ["host"] = host, + ["connection"] = "close, TE", + ["te"] = "trailers" + } + -- if we have authentication information, pass it along + if reqt.user and reqt.password then + lower["authorization"] = + "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + end + -- override with user headers + for i,v in base.pairs(reqt.headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +-- default url parts +local default = { + host = "", + port = _M.PORT, + path ="/", + scheme = "http" +} + +local function adjustrequest(reqt) + -- parse url if provided + local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- explicit components override url + for i,v in base.pairs(reqt) do nreqt[i] = v end + if nreqt.port == "" then nreqt.port = 80 end + socket.try(nreqt.host and nreqt.host ~= "", + "invalid host '" .. base.tostring(nreqt.host) .. "'") + -- compute uri if user hasn't overriden + nreqt.uri = reqt.uri or adjusturi(nreqt) + -- ajust host and port if there is a proxy + nreqt.host, nreqt.port = adjustproxy(nreqt) + -- adjust headers in request + nreqt.headers = adjustheaders(nreqt) + return nreqt +end + +local function shouldredirect(reqt, code, headers) + return headers.location and + string.gsub(headers.location, "%s", "") ~= "" and + (reqt.redirect ~= false) and + (code == 301 or code == 302 or code == 303 or code == 307) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end + +local function shouldreceivebody(reqt, code) + if reqt.method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 +end + +-- forward declarations +local trequest, tredirect + +--[[local]] function tredirect(reqt, location) + local result, code, headers, status = trequest { + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = url.absolute(reqt.url, location), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1, + create = reqt.create + } + -- pass location header back as a hint we redirected + headers = headers or {} + headers.location = headers.location or location + return result, code, headers, status +end + +--[[local]] function trequest(reqt) + -- we loop until we get what we want, or + -- until we are sure there is no way to get it + local nreqt = adjustrequest(reqt) + local h = _M.open(nreqt) + -- send request line and headers + h:sendrequestline(nreqt.method, nreqt.uri) + h:sendheaders(nreqt.headers) + -- if there is a body, send it + if nreqt.source then + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + end + local code, status = h:receivestatusline() + -- if it is an HTTP/0.9 server, simply get the body and we are done + if not code then + h:receive09body(status, nreqt.sink, nreqt.step) + return 1, 200 + end + local headers + -- ignore any 100-continue messages + while code == 100 do + headers = h:receiveheaders() + code, status = h:receivestatusline() + end + headers = h:receiveheaders() + -- at this point we should have a honest reply from the server + -- we can't redirect if we already used the source, so we report the error + if shouldredirect(nreqt, code, headers) and not nreqt.source then + h:close() + return tredirect(reqt, headers.location) + end + -- here we are finally done + if shouldreceivebody(nreqt, code) then + h:receivebody(headers, nreqt.sink, nreqt.step) + end + h:close() + return 1, code, headers, status +end + +-- Return a function which performs the SSL/TLS connection. +local function tcp(params) + params = params or {} + -- Default settings + params.protocol = params.protocol or _M.SSLPROTOCOL + params.options = params.options or _M.SSLOPTIONS + params.verify = params.verify or _M.SSLVERIFY + params.mode = "client" -- Force client mode + -- upvalue to track https -> http redirection + local washttps = false + -- 'create' function for LuaSocket + return function (reqt) + local u = url.parse(reqt.url) + if (reqt.scheme or u.scheme) == "https" then + -- https, provide an ssl wrapped socket + local conn = copas.wrap(socket.tcp(), params) + -- insert https default port, overriding http port inserted by LuaSocket + if not u.port then + u.port = _M.SSLPORT + reqt.url = url.build(u) + reqt.port = _M.SSLPORT + end + washttps = true + return conn + else + -- regular http, needs just a socket... + if washttps and params.redirect ~= "all" then + try(nil, "Unallowed insecure redirect https to http") + end + return copas.wrap(socket.tcp()) + end + end +end + +-- parses a shorthand form into the advanced table form. +-- adds field `target` to the table. This will hold the return values. +_M.parseRequest = function(u, b) + local reqt = { + url = u, + target = {}, + } + reqt.sink = ltn12.sink.table(reqt.target) + if b then + reqt.source = ltn12.source.string(b) + reqt.headers = { + ["content-length"] = string.len(b), + ["content-type"] = "application/x-www-form-urlencoded" + } + reqt.method = "POST" + end + return reqt +end + +_M.request = socket.protect(function(reqt, body) + if base.type(reqt) == "string" then + reqt = _M.parseRequest(reqt, body) + local ok, code, headers, status = _M.request(reqt) + + if ok then + return table.concat(reqt.target), code, headers, status + else + return nil, code + end + else + reqt.create = reqt.create or tcp(reqt) + return trequest(reqt) + end +end) + +return _M diff --git a/copas/limit.lua b/copas/limit.lua new file mode 100644 index 0000000..e26d537 --- /dev/null +++ b/copas/limit.lua @@ -0,0 +1,99 @@ +-------------------------------------------------------------- +-- Limits resource usage while executing tasks. +-- Tasks added will be run in parallel, with a maximum of +-- simultaneous tasks to prevent consuming all/too many resources. +-- Every task added will immediately be scheduled (if there is room) +-- using the `wait` method one can wait for completion. + +local copas = require("copas") +local pack = table.pack or function(...) return {n=select('#',...),...} end +local unpack = function(t) return (table.unpack or unpack)(t, 1, t.n or #t) end + +local pcall = pcall +if _VERSION=="Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility + pcall = require("coxpcall").pcall +end + +-- Add a task to the queue, returns the coroutine created +-- identical to `copas.addthread`. Can be called while the +-- set of tasks is executing. +local function add(self, task, ...) + local carg = pack(...) + local coro = copas.addthread(function() + copas.sleep(-1) -- go to sleep until being woken + local suc, err = pcall(task, unpack(carg)) -- start the task + self:removethread(coroutine.running()) -- dismiss ourselves + if not suc then error(err) end -- rethrow error + end) + table.insert(self.queue, coro) -- store in list + self:next() + return coro +end + +-- remove a task from the queue. Can be called while the +-- set of tasks is executing. Will NOT stop the task if +-- it is already running. +local function remove(self, coro) + self.queue[coro] = nil + if self.running[coro] then + -- it is in the already running set + self.running[coro] = nil + self.count = self.count - 1 + else + -- check the queue and remove if found + for i, item in ipairs(self.queue) do + if coro == item then + table.remove(self.queue, i) + break + end + end + end + self:next() +end + +-- schedules the next task (if any) for execution, signals completeness +local function nxt(self) + while self.count < self.maxt do + local coro = self.queue[1] + if not coro then break end -- queue is empty, so nothing to add + -- move it to running and restart the task + table.remove(self.queue, 1) + self.running[coro] = coro + self.count = self.count + 1 + copas.wakeup(coro) + end + if self.count == 0 and next(self.waiting) then + -- all tasks done, resume the waiting tasks so they can unblock/return + for coro in pairs(self.waiting) do + copas.wakeup(coro) + end + end +end + +-- Waits for the tasks. Yields until all are finished +local function wait(self) + if self.count == 0 then return end -- There's nothing to do... + local coro = coroutine.running() + -- now store this coroutine (so we know which to wakeup) and go to sleep + self.waiting[coro] = true + copas.sleep(-1) + self.waiting[coro] = nil +end + +-- creats a new tasksrunner, with maximum maxt simultaneous threads +local function new(maxt) + return { + maxt = maxt or 99999, -- max simultaneous tasks + count = 0, -- count of running tasks + queue = {}, -- tasks waiting (list/array) + running = {}, -- tasks currently running (indexed by coroutine) + waiting = {}, -- coroutines, waiting for all tasks being finished (indexed by coro) + addthread = add, + removethread = remove, + next = nxt, + wait = wait, + } +end + +return { new = new } + diff --git a/copas/limit.lua~ b/copas/limit.lua~ new file mode 100644 index 0000000..e26d537 --- /dev/null +++ b/copas/limit.lua~ @@ -0,0 +1,99 @@ +-------------------------------------------------------------- +-- Limits resource usage while executing tasks. +-- Tasks added will be run in parallel, with a maximum of +-- simultaneous tasks to prevent consuming all/too many resources. +-- Every task added will immediately be scheduled (if there is room) +-- using the `wait` method one can wait for completion. + +local copas = require("copas") +local pack = table.pack or function(...) return {n=select('#',...),...} end +local unpack = function(t) return (table.unpack or unpack)(t, 1, t.n or #t) end + +local pcall = pcall +if _VERSION=="Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility + pcall = require("coxpcall").pcall +end + +-- Add a task to the queue, returns the coroutine created +-- identical to `copas.addthread`. Can be called while the +-- set of tasks is executing. +local function add(self, task, ...) + local carg = pack(...) + local coro = copas.addthread(function() + copas.sleep(-1) -- go to sleep until being woken + local suc, err = pcall(task, unpack(carg)) -- start the task + self:removethread(coroutine.running()) -- dismiss ourselves + if not suc then error(err) end -- rethrow error + end) + table.insert(self.queue, coro) -- store in list + self:next() + return coro +end + +-- remove a task from the queue. Can be called while the +-- set of tasks is executing. Will NOT stop the task if +-- it is already running. +local function remove(self, coro) + self.queue[coro] = nil + if self.running[coro] then + -- it is in the already running set + self.running[coro] = nil + self.count = self.count - 1 + else + -- check the queue and remove if found + for i, item in ipairs(self.queue) do + if coro == item then + table.remove(self.queue, i) + break + end + end + end + self:next() +end + +-- schedules the next task (if any) for execution, signals completeness +local function nxt(self) + while self.count < self.maxt do + local coro = self.queue[1] + if not coro then break end -- queue is empty, so nothing to add + -- move it to running and restart the task + table.remove(self.queue, 1) + self.running[coro] = coro + self.count = self.count + 1 + copas.wakeup(coro) + end + if self.count == 0 and next(self.waiting) then + -- all tasks done, resume the waiting tasks so they can unblock/return + for coro in pairs(self.waiting) do + copas.wakeup(coro) + end + end +end + +-- Waits for the tasks. Yields until all are finished +local function wait(self) + if self.count == 0 then return end -- There's nothing to do... + local coro = coroutine.running() + -- now store this coroutine (so we know which to wakeup) and go to sleep + self.waiting[coro] = true + copas.sleep(-1) + self.waiting[coro] = nil +end + +-- creats a new tasksrunner, with maximum maxt simultaneous threads +local function new(maxt) + return { + maxt = maxt or 99999, -- max simultaneous tasks + count = 0, -- count of running tasks + queue = {}, -- tasks waiting (list/array) + running = {}, -- tasks currently running (indexed by coroutine) + waiting = {}, -- coroutines, waiting for all tasks being finished (indexed by coro) + addthread = add, + removethread = remove, + next = nxt, + wait = wait, + } +end + +return { new = new } + diff --git a/copas/limit.lua~~ b/copas/limit.lua~~ new file mode 100644 index 0000000..e26d537 --- /dev/null +++ b/copas/limit.lua~~ @@ -0,0 +1,99 @@ +-------------------------------------------------------------- +-- Limits resource usage while executing tasks. +-- Tasks added will be run in parallel, with a maximum of +-- simultaneous tasks to prevent consuming all/too many resources. +-- Every task added will immediately be scheduled (if there is room) +-- using the `wait` method one can wait for completion. + +local copas = require("copas") +local pack = table.pack or function(...) return {n=select('#',...),...} end +local unpack = function(t) return (table.unpack or unpack)(t, 1, t.n or #t) end + +local pcall = pcall +if _VERSION=="Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility + pcall = require("coxpcall").pcall +end + +-- Add a task to the queue, returns the coroutine created +-- identical to `copas.addthread`. Can be called while the +-- set of tasks is executing. +local function add(self, task, ...) + local carg = pack(...) + local coro = copas.addthread(function() + copas.sleep(-1) -- go to sleep until being woken + local suc, err = pcall(task, unpack(carg)) -- start the task + self:removethread(coroutine.running()) -- dismiss ourselves + if not suc then error(err) end -- rethrow error + end) + table.insert(self.queue, coro) -- store in list + self:next() + return coro +end + +-- remove a task from the queue. Can be called while the +-- set of tasks is executing. Will NOT stop the task if +-- it is already running. +local function remove(self, coro) + self.queue[coro] = nil + if self.running[coro] then + -- it is in the already running set + self.running[coro] = nil + self.count = self.count - 1 + else + -- check the queue and remove if found + for i, item in ipairs(self.queue) do + if coro == item then + table.remove(self.queue, i) + break + end + end + end + self:next() +end + +-- schedules the next task (if any) for execution, signals completeness +local function nxt(self) + while self.count < self.maxt do + local coro = self.queue[1] + if not coro then break end -- queue is empty, so nothing to add + -- move it to running and restart the task + table.remove(self.queue, 1) + self.running[coro] = coro + self.count = self.count + 1 + copas.wakeup(coro) + end + if self.count == 0 and next(self.waiting) then + -- all tasks done, resume the waiting tasks so they can unblock/return + for coro in pairs(self.waiting) do + copas.wakeup(coro) + end + end +end + +-- Waits for the tasks. Yields until all are finished +local function wait(self) + if self.count == 0 then return end -- There's nothing to do... + local coro = coroutine.running() + -- now store this coroutine (so we know which to wakeup) and go to sleep + self.waiting[coro] = true + copas.sleep(-1) + self.waiting[coro] = nil +end + +-- creats a new tasksrunner, with maximum maxt simultaneous threads +local function new(maxt) + return { + maxt = maxt or 99999, -- max simultaneous tasks + count = 0, -- count of running tasks + queue = {}, -- tasks waiting (list/array) + running = {}, -- tasks currently running (indexed by coroutine) + waiting = {}, -- coroutines, waiting for all tasks being finished (indexed by coro) + addthread = add, + removethread = remove, + next = nxt, + wait = wait, + } +end + +return { new = new } + diff --git a/copas/lock.lua b/copas/lock.lua new file mode 100644 index 0000000..7b33c0d --- /dev/null +++ b/copas/lock.lua @@ -0,0 +1,191 @@ +local copas = require("copas") +local gettime = require("socket").gettime + +local DEFAULT_TIMEOUT = 10 + +local lock = {} +lock.__index = lock + + +-- registry, locks indexed by the coroutines using them. +local registry = setmetatable({}, { __mode="kv" }) + + + +--- Creates a new lock. +-- @param seconds (optional) default timeout in seconds when acquiring the lock (defaults to 10), +-- set to `math.huge` to have no timeout. +-- @param not_reentrant (optional) if truthy the lock will not allow a coroutine to grab the same lock multiple times +-- @return the lock object +function lock.new(seconds, not_reentrant) + local timeout = tonumber(seconds or DEFAULT_TIMEOUT) or -1 + if timeout < 0 then + error("expected timeout (1st argument) to be a number greater than or equal to 0, got: " .. tostring(seconds), 2) + end + return setmetatable({ + timeout = timeout, + not_reentrant = not_reentrant, + queue = {}, + q_tip = 0, -- index of the first in line waiting + q_tail = 0, -- index where the next one will be inserted + owner = nil, -- coroutine holding lock currently + call_count = nil, -- recursion call count + errors = setmetatable({}, { __mode = "k" }), -- error indexed by coroutine + }, lock) +end + + + +do + local destroyed_func = function() + return nil, "destroyed" + end + + local destroyed_lock_mt = { + __index = function() + return destroyed_func + end + } + + --- destroy a lock. + -- Releases all waiting threads with `nil+"destroyed"` + function lock:destroy() + --print("destroying ",self) + for i = self.q_tip, self.q_tail do + local co = self.queue[i] + self.queue[i] = nil + + if co then + self.errors[co] = "destroyed" + --print("marked destroyed ", co) + copas.wakeup(co) + end + end + + if self.owner then + self.errors[self.owner] = "destroyed" + --print("marked destroyed ", co) + end + self.queue = {} + self.q_tip = 0 + self.q_tail = 0 + self.destroyed = true + + setmetatable(self, destroyed_lock_mt) + return true + end +end + + +local function timeout_handler(co) + local self = registry[co] + if not self then + return + end + + for i = self.q_tip, self.q_tail do + if co == self.queue[i] then + self.queue[i] = nil + self.errors[co] = "timeout" + --print("marked timeout ", co) + copas.wakeup(co) + return + end + end + -- if we get here, we own it currently, or we finished it by now, or + -- the lock was destroyed. Anyway, nothing to do here... +end + + +--- Acquires the lock. +-- If the lock is owned by another thread, this will yield control, until the +-- lock becomes available, or it times out. +-- If `timeout == 0` then it will immediately return (without yielding). +-- @param timeout (optional) timeout in seconds, defaults to the timeout passed to `new` (use `math.huge` to have no timeout). +-- @return wait-time on success, or nil+error+wait_time on failure. Errors can be "timeout", "destroyed", or "lock is not re-entrant" +function lock:get(timeout) + local co = coroutine.running() + local start_time + + -- is the lock already taken? + if self.owner then + -- are we re-entering? + if co == self.owner and not self.not_reentrant then + self.call_count = self.call_count + 1 + return 0 + end + + self.queue[self.q_tail] = co + self.q_tail = self.q_tail + 1 + timeout = timeout or self.timeout + if timeout == 0 then + return nil, "timeout", 0 + end + + -- set up timeout + registry[co] = self + copas.timeout(timeout, timeout_handler) + + start_time = gettime() + copas.pauseforever() + + local err = self.errors[co] + self.errors[co] = nil + registry[co] = nil + + --print("released ", co, err) + if err ~= "timeout" then + copas.timeout(0) + end + if err then + return nil, err, gettime() - start_time + end + end + + -- it's ours to have + self.owner = co + self.call_count = 1 + return start_time and (gettime() - start_time) or 0 +end + + +--- Releases the lock currently held. +-- Releasing a lock that is not owned by the current co-routine will return +-- an error. +-- returns true, or nil+err on an error +function lock:release() + local co = coroutine.running() + + if co ~= self.owner then + return nil, "cannot release a lock not owned" + end + + self.call_count = self.call_count - 1 + if self.call_count > 0 then + -- same coro is still holding it + return true + end + + -- need a loop, since individual coroutines might have been removed + -- so there might be holes + while self.q_tip < self.q_tail do + local next_up = self.queue[self.q_tip] + if next_up then + self.owner = next_up + self.queue[self.q_tip] = nil + self.q_tip = self.q_tip + 1 + copas.wakeup(next_up) + return true + end + self.q_tip = self.q_tip + 1 + end + -- queue is empty, reset pointers + self.owner = nil + self.q_tip = 0 + self.q_tail = 0 + return true +end + + + +return lock diff --git a/copas/queue.lua b/copas/queue.lua new file mode 100644 index 0000000..01cd53c --- /dev/null +++ b/copas/queue.lua @@ -0,0 +1,191 @@ +local copas = require "copas" +local Sema = copas.semaphore +local Lock = copas.lock + + +local Queue = {} +Queue.__index = Queue + + +local new_name do + local count = 0 + + function new_name() + count = count + 1 + return "copas_queue_" .. count + end +end + + +-- Creates a new Queue instance +function Queue.new(opts) + opts = opts or {} + local self = {} + setmetatable(self, Queue) + self.name = opts.name or new_name() + self.sema = Sema.new(10^9) + self.head = 1 + self.tail = 1 + self.list = {} + self.workers = setmetatable({}, { __mode = "k" }) + self.stopping = false + self.worker_id = 0 + return self +end + + +-- Pushes an item in the queue (can be 'nil') +-- returns true, or nil+err ("stopping", or "destroyed") +function Queue:push(item) + if self.stopping then + return nil, "stopping" + end + self.list[self.head] = item + self.head = self.head + 1 + self.sema:give() + return true +end + + +-- Pops and item from the queue. If there are no items in the queue it will yield +-- until there are or a timeout happens (exception is when `timeout == 0`, then it will +-- not yield but return immediately). If the timeout is `math.huge` it will wait forever. +-- Returns item, or nil+err ("timeout", or "destroyed") +function Queue:pop(timeout) + local ok, err = self.sema:take(1, timeout) + if not ok then + return ok, err + end + + local item = self.list[self.tail] + self.list[self.tail] = nil + self.tail = self.tail + 1 + + if self.tail == self.head then + -- reset queue + self.list = {} + self.tail = 1 + self.head = 1 + if self.stopping then + -- we're stopping and last item being returned, so we're done + self:destroy() + end + end + return item +end + + +-- return the number of items left in the queue +function Queue:get_size() + return self.head - self.tail +end + + +-- instructs the queue to stop. Will not accept any more 'push' calls. +-- will autocall 'destroy' when the queue is empty. +-- returns immediately. See `finish` +function Queue:stop() + if not self.stopping then + self.stopping = true + self.lock = Lock.new(nil, true) + self.lock:get() -- close the lock + if self:get_size() == 0 then + -- queue is already empty, so "pop" function cannot call destroy on next + -- pop, so destroy now. + self:destroy() + end + end + return true +end + + +-- Finishes a queue. Calls stop and then waits for the queue to run empty (and be +-- destroyed) before returning. returns true or nil+err ("timeout", or "destroyed") +-- Parameter no_destroy_on_timeout indicates if the queue is not to be forcefully +-- destroyed on a timeout. +function Queue:finish(timeout, no_destroy_on_timeout) + self:stop() + local _, err = self.lock:get(timeout) + -- the lock never gets released, only destroyed, so we have to check the error string + if err == "timeout" then + if not no_destroy_on_timeout then + self:destroy() + end + return nil, err + end + return true +end + + +do + local destroyed_func = function() + return nil, "destroyed" + end + + local destroyed_queue_mt = { + __index = function() + return destroyed_func + end + } + + -- destroys a queue immediately. Abandons what is left in the queue. + -- Releases all waiting threads with `nil+"destroyed"` + function Queue:destroy() + if self.lock then + self.lock:destroy() + end + self.sema:destroy() + setmetatable(self, destroyed_queue_mt) + + -- clear anything left in the queue + for key in pairs(self.list) do + self.list[key] = nil + end + + return true + end +end + + +-- adds a worker that will handle whatever is passed into the queue. Can be called +-- multiple times to add more workers. +-- The threads automatically exit when the queue is destroyed. +-- worker function signature: `function(item)` (Note: worker functions run +-- unprotected, so wrap code in an (x)pcall if errors are expected, otherwise the +-- worker will exit on an error, and queue handling will stop) +-- Returns the coroutine added. +function Queue:add_worker(worker) + assert(type(worker) == "function", "expected worker to be a function") + local coro + + self.worker_id = self.worker_id + 1 + local worker_name = self.name .. ":worker_" .. self.worker_id + + coro = copas.addnamedthread(worker_name, function() + while true do + local item, err = self:pop(math.huge) -- wait forever + if err then + break -- queue destroyed, exit + end + worker(item) -- TODO: wrap in errorhandling + end + self.workers[coro] = nil + end) + + self.workers[coro] = true + return coro +end + +-- returns a list/array of current workers (coroutines) handling the queue. +-- (only the workers added by `add_worker`, and still active, will be in this list) +function Queue:get_workers() + local lst = {} + for coro in pairs(self.workers) do + if coroutine.status(coro) ~= "dead" then + lst[#lst+1] = coro + end + end + return lst +end + +return Queue diff --git a/copas/semaphore.lua b/copas/semaphore.lua new file mode 100644 index 0000000..0f4fda3 --- /dev/null +++ b/copas/semaphore.lua @@ -0,0 +1,202 @@ +local copas = require("copas") + +local DEFAULT_TIMEOUT = 10 + +local semaphore = {} +semaphore.__index = semaphore + + +-- registry, semaphore indexed by the coroutines using them. +local registry = setmetatable({}, { __mode="kv" }) + + +-- create a new semaphore +-- @param max maximum number of resources the semaphore can hold (this maximum does NOT include resources that have been given but not yet returned). +-- @param start (optional, default 0) the initial resources available +-- @param seconds (optional, default 10) default semaphore timeout in seconds, or `math.huge` to have no timeout. +function semaphore.new(max, start, seconds) + local timeout = tonumber(seconds or DEFAULT_TIMEOUT) or -1 + if timeout < 0 then + error("expected timeout (2nd argument) to be a number greater than or equal to 0, got: " .. tostring(seconds), 2) + end + if type(max) ~= "number" or max < 1 then + error("expected max resources (1st argument) to be a number greater than 0, got: " .. tostring(max), 2) + end + + local self = setmetatable({ + count = start or 0, + max = max, + timeout = timeout, + q_tip = 1, -- position of next entry waiting + q_tail = 1, -- position where next one will be inserted + queue = {}, + to_flags = setmetatable({}, { __mode = "k" }), -- timeout flags indexed by coroutine + }, semaphore) + + return self +end + + +do + local destroyed_func = function() + return nil, "destroyed" + end + + local destroyed_semaphore_mt = { + __index = function() + return destroyed_func + end + } + + -- destroy a semaphore. + -- Releases all waiting threads with `nil+"destroyed"` + function semaphore:destroy() + self:give(math.huge) + self.destroyed = true + setmetatable(self, destroyed_semaphore_mt) + return true + end +end + + +-- Gives resources. +-- @param given (optional, default 1) number of resources to return. If more +-- than the maximum are returned then it will be capped at the maximum and +-- error "too many" will be returned. +function semaphore:give(given) + local err + given = given or 1 + local count = self.count + given + --print("now at",count, ", after +"..given) + if count > self.max then + count = self.max + err = "too many" + end + + while self.q_tip < self.q_tail do + local i = self.q_tip + local nxt = self.queue[i] -- there can be holes, so nxt might be nil + if not nxt then + self.q_tip = i + 1 + else + if count >= nxt.requested then + -- release it + self.queue[i] = nil + self.to_flags[nxt.co] = nil + count = count - nxt.requested + self.q_tip = i + 1 + copas.wakeup(nxt.co) + nxt.co = nil + else + break -- we ran out of resources + end + end + end + + if self.q_tip == self.q_tail then -- reset queue + self.queue = {} + self.q_tip = 1 + self.q_tail = 1 + end + + self.count = count + if err then + return nil, err + end + return true +end + + + +local function timeout_handler(co) + local self = registry[co] + --print("checking timeout ", co) + if not self then + return + end + + for i = self.q_tip, self.q_tail do + local item = self.queue[i] + if item and co == item.co then + self.queue[i] = nil + self.to_flags[co] = true + --print("marked timeout ", co) + copas.wakeup(co) + return + end + end + -- nothing to do here... +end + + +-- Requests resources from the semaphore. +-- Waits if there are not enough resources available before returning. +-- @param requested (optional, default 1) the number of resources requested +-- @param timeout (optional, defaults to semaphore timeout) timeout in +-- seconds. If 0 it will either succeed or return immediately with error "timeout". +-- If `math.huge` it will wait forever. +-- @return true, or nil+"destroyed" +function semaphore:take(requested, timeout) + requested = requested or 1 + if self.q_tail == 1 and self.count >= requested then + -- nobody is waiting before us, and there is enough in store + self.count = self.count - requested + return true + end + + if requested > self.max then + return nil, "too many" + end + + local to = timeout or self.timeout + if to == 0 then + return nil, "timeout" + end + + -- get in line + local co = coroutine.running() + self.to_flags[co] = nil + registry[co] = self + copas.timeout(to, timeout_handler) + + self.queue[self.q_tail] = { + co = co, + requested = requested, + --timeout = nil, -- flag indicating timeout + } + self.q_tail = self.q_tail + 1 + + copas.pauseforever() -- block until woken + registry[co] = nil + + if self.to_flags[co] then + -- a timeout happened + self.to_flags[co] = nil + return nil, "timeout" + end + + copas.timeout(0) + + if self.destroyed then + return nil, "destroyed" + end + + return true +end + +-- returns current available resources +function semaphore:get_count() + return self.count +end + +-- returns total shortage for requested resources +function semaphore:get_wait() + local wait = 0 + for i = self.q_tip, self.q_tail - 1 do + wait = wait + ((self.queue[i] or {}).requested or 0) + end + return wait - self.count +end + + +return semaphore diff --git a/copas/smtp.lua b/copas/smtp.lua new file mode 100644 index 0000000..0d175eb --- /dev/null +++ b/copas/smtp.lua @@ -0,0 +1,33 @@ +------------------------------------------------------------------- +-- identical to the socket.smtp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local smtp = require("socket.smtp") + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + SERVER = true, + TIMEOUT = true, + DOMAIN = true, + TIMEZONE = true +} + +copas.smtp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.SERVER, etc. + __index = smtp, + -- Setting constants is forwarded to the luasocket.smtp module. + __newindex = function(self, key, value) + if forwards[key] then smtp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.smtp + +_M.send = function(mailt) + mailt.create = mailt.create or create + return smtp.send(mailt) +end + +return _M \ No newline at end of file diff --git a/copas/smtp.lua~ b/copas/smtp.lua~ new file mode 100644 index 0000000..0d175eb --- /dev/null +++ b/copas/smtp.lua~ @@ -0,0 +1,33 @@ +------------------------------------------------------------------- +-- identical to the socket.smtp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local smtp = require("socket.smtp") + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + SERVER = true, + TIMEOUT = true, + DOMAIN = true, + TIMEZONE = true +} + +copas.smtp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.SERVER, etc. + __index = smtp, + -- Setting constants is forwarded to the luasocket.smtp module. + __newindex = function(self, key, value) + if forwards[key] then smtp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.smtp + +_M.send = function(mailt) + mailt.create = mailt.create or create + return smtp.send(mailt) +end + +return _M \ No newline at end of file diff --git a/copas/smtp.lua~~ b/copas/smtp.lua~~ new file mode 100644 index 0000000..0d175eb --- /dev/null +++ b/copas/smtp.lua~~ @@ -0,0 +1,33 @@ +------------------------------------------------------------------- +-- identical to the socket.smtp module except that it uses +-- async wrapped Copas sockets + +local copas = require("copas") +local smtp = require("socket.smtp") + +local create = function() return copas.wrap(socket.tcp()) end +local forwards = { -- setting these will be forwarded to the original smtp module + PORT = true, + SERVER = true, + TIMEOUT = true, + DOMAIN = true, + TIMEZONE = true +} + +copas.smtp = setmetatable({}, { + -- use original module as metatable, to lookup constants like socket.SERVER, etc. + __index = smtp, + -- Setting constants is forwarded to the luasocket.smtp module. + __newindex = function(self, key, value) + if forwards[key] then smtp[key] = value return end + return rawset(self, key, value) + end, + }) +local _M = copas.smtp + +_M.send = function(mailt) + mailt.create = mailt.create or create + return smtp.send(mailt) +end + +return _M \ No newline at end of file diff --git a/copas/timer.lua b/copas/timer.lua new file mode 100644 index 0000000..09041ea --- /dev/null +++ b/copas/timer.lua @@ -0,0 +1,130 @@ +local copas = require("copas") + +local xpcall = xpcall +local coroutine_running = coroutine.running + +if _VERSION=="Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility + xpcall = require("coxpcall").xpcall + coroutine_running = require("coxpcall").running +end + + +local timer = {} +timer.__index = timer + + +local new_name do + local count = 0 + + function new_name() + count = count + 1 + return "copas_timer_" .. count + end +end + + +do + local function expire_func(self, initial_delay) + if self.errorhandler then + copas.seterrorhandler(self.errorhandler) + end + copas.pause(initial_delay) + while true do + if not self.cancelled then + if not self.recurring then + -- non-recurring timer + self.cancelled = true + self.co = nil + + self:callback(self.params) + return + + else + -- recurring timer + self:callback(self.params) + end + end + + if self.cancelled then + -- clean up and exit the thread + self.co = nil + self.cancelled = true + return + end + + copas.pause(self.delay) + end + end + + + --- Arms the timer object. + -- @param initial_delay (optional) the first delay to use, if not provided uses the timer delay + -- @return timer object, nil+error, or throws an error on bad input + function timer:arm(initial_delay) + assert(initial_delay == nil or initial_delay >= 0, "delay must be greater than or equal to 0") + if self.co then + return nil, "already armed" + end + + self.cancelled = false + self.co = copas.addnamedthread(self.name, expire_func, self, initial_delay or self.delay) + return self + end +end + + + +--- Cancels a running timer. +-- @return timer object, or nil+error +function timer:cancel() + if not self.co then + return nil, "not armed" + end + + if self.cancelled then + return nil, "already cancelled" + end + + self.cancelled = true + copas.wakeup(self.co) -- resume asap + copas.removethread(self.co) -- will immediately drop the thread upon resuming + self.co = nil + return self +end + + +do + -- xpcall error handler that forwards to the copas errorhandler + local ehandler = function(err_obj) + return copas.geterrorhandler()(err_obj, coroutine_running(), nil) + end + + + --- Creates a new timer object. + -- Note: the callback signature is: `function(timer_obj, params)`. + -- @param opts (table) `opts.delay` timer delay in seconds, `opts.callback` function to execute, `opts.recurring` boolean + -- `opts.params` (optional) this value will be passed to the timer callback, `opts.initial_delay` (optional) the first delay to use, defaults to `delay`. + -- @return timer object, or throws an error on bad input + function timer.new(opts) + assert(opts.delay or -1 >= 0, "delay must be greater than or equal to 0") + assert(type(opts.callback) == "function", "expected callback to be a function") + + local callback = function(timer_obj, params) + xpcall(opts.callback, ehandler, timer_obj, params) + end + + return setmetatable({ + name = opts.name or new_name(), + delay = opts.delay, + callback = callback, + recurring = not not opts.recurring, + params = opts.params, + cancelled = false, + errorhandler = opts.errorhandler, + }, timer):arm(opts.initial_delay) + end +end + + + +return timer diff --git a/ltn12.lua b/ltn12.lua new file mode 100644 index 0000000..5b10f56 --- /dev/null +++ b/ltn12.lua @@ -0,0 +1,298 @@ +----------------------------------------------------------------------------- +-- LTN12 - Filters, sources, sinks and pumps. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module +----------------------------------------------------------------------------- +local string = require("string") +local table = require("table") +local base = _G +local _M = {} +if module then -- heuristic for exporting a global package table + ltn12 = _M +end +local filter,source,sink,pump = {},{},{},{} + +_M.filter = filter +_M.source = source +_M.sink = sink +_M.pump = pump + +-- 2048 seems to be better in windows... +_M.BLOCKSIZE = 2048 +_M._VERSION = "LTN12 1.0.3" + +----------------------------------------------------------------------------- +-- Filter stuff +----------------------------------------------------------------------------- +-- returns a high level filter that cycles a low-level filter +function filter.cycle(low, ctx, extra) + base.assert(low) + return function(chunk) + local ret + ret, ctx = low(ctx, chunk, extra) + return ret + end +end + +-- chains a bunch of filters together +-- (thanks to Wim Couwenberg) +function filter.chain(...) + local arg = {...} + local n = select('#',...) + local top, index = 1, 1 + local retry = "" + return function(chunk) + retry = chunk and retry + while true do + if index == top then + chunk = arg[index](chunk) + if chunk == "" or top == n then return chunk + elseif chunk then index = index + 1 + else + top = top+1 + index = top + end + else + chunk = arg[index](chunk or "") + if chunk == "" then + index = index - 1 + chunk = retry + elseif chunk then + if index == n then return chunk + else index = index + 1 end + else base.error("filter returned inappropriate nil") end + end + end + end +end + +----------------------------------------------------------------------------- +-- Source stuff +----------------------------------------------------------------------------- +-- create an empty source +local function empty() + return nil +end + +function source.empty() + return empty +end + +-- returns a source that just outputs an error +function source.error(err) + return function() + return nil, err + end +end + +-- creates a file source +function source.file(handle, io_err) + if handle then + return function() + local chunk = handle:read(_M.BLOCKSIZE) + if not chunk then handle:close() end + return chunk + end + else return source.error(io_err or "unable to open file") end +end + +-- turns a fancy source into a simple source +function source.simplify(src) + base.assert(src) + return function() + local chunk, err_or_new = src() + src = err_or_new or src + if not chunk then return nil, err_or_new + else return chunk end + end +end + +-- creates string source +function source.string(s) + if s then + local i = 1 + return function() + local chunk = string.sub(s, i, i+_M.BLOCKSIZE-1) + i = i + _M.BLOCKSIZE + if chunk ~= "" then return chunk + else return nil end + end + else return source.empty() end +end + +-- creates rewindable source +function source.rewind(src) + base.assert(src) + local t = {} + return function(chunk) + if not chunk then + chunk = table.remove(t) + if not chunk then return src() + else return chunk end + else + table.insert(t, chunk) + end + end +end + +function source.chain(src, f) + base.assert(src and f) + local last_in, last_out = "", "" + local state = "feeding" + local err + return function() + if not last_out then + base.error('source is empty!', 2) + end + while true do + if state == "feeding" then + last_in, err = src() + if err then return nil, err end + last_out = f(last_in) + if not last_out then + if last_in then + base.error('filter returned inappropriate nil') + else + return nil + end + elseif last_out ~= "" then + state = "eating" + if last_in then last_in = "" end + return last_out + end + else + last_out = f(last_in) + if last_out == "" then + if last_in == "" then + state = "feeding" + else + base.error('filter returned ""') + end + elseif not last_out then + if last_in then + base.error('filter returned inappropriate nil') + else + return nil + end + else + return last_out + end + end + end + end +end + +-- creates a source that produces contents of several sources, one after the +-- other, as if they were concatenated +-- (thanks to Wim Couwenberg) +function source.cat(...) + local arg = {...} + local src = table.remove(arg, 1) + return function() + while src do + local chunk, err = src() + if chunk then return chunk end + if err then return nil, err end + src = table.remove(arg, 1) + end + end +end + +----------------------------------------------------------------------------- +-- Sink stuff +----------------------------------------------------------------------------- +-- creates a sink that stores into a table +function sink.table(t) + t = t or {} + local f = function(chunk, err) + if chunk then table.insert(t, chunk) end + return 1 + end + return f, t +end + +-- turns a fancy sink into a simple sink +function sink.simplify(snk) + base.assert(snk) + return function(chunk, err) + local ret, err_or_new = snk(chunk, err) + if not ret then return nil, err_or_new end + snk = err_or_new or snk + return 1 + end +end + +-- creates a file sink +function sink.file(handle, io_err) + if handle then + return function(chunk, err) + if not chunk then + handle:close() + return 1 + else return handle:write(chunk) end + end + else return sink.error(io_err or "unable to open file") end +end + +-- creates a sink that discards data +local function null() + return 1 +end + +function sink.null() + return null +end + +-- creates a sink that just returns an error +function sink.error(err) + return function() + return nil, err + end +end + +-- chains a sink with a filter +function sink.chain(f, snk) + base.assert(f and snk) + return function(chunk, err) + if chunk ~= "" then + local filtered = f(chunk) + local done = chunk and "" + while true do + local ret, snkerr = snk(filtered, err) + if not ret then return nil, snkerr end + if filtered == done then return 1 end + filtered = f(done) + end + else return 1 end + end +end + +----------------------------------------------------------------------------- +-- Pump stuff +----------------------------------------------------------------------------- +-- pumps one chunk from the source to the sink +function pump.step(src, snk) + local chunk, src_err = src() + local ret, snk_err = snk(chunk, src_err) + if chunk and ret then return 1 + else return nil, src_err or snk_err end +end + +-- pumps all data from a source to a sink, using a step function +function pump.all(src, snk, step) + base.assert(src and snk) + step = step or pump.step + while true do + local ret, err = step(src, snk) + if not ret then + if err then return nil, err + else return 1 end + end + end +end + +return _M diff --git a/mime.lua b/mime.lua new file mode 100644 index 0000000..642cd9c --- /dev/null +++ b/mime.lua @@ -0,0 +1,90 @@ +----------------------------------------------------------------------------- +-- MIME support for the Lua language. +-- Author: Diego Nehab +-- Conforming to RFCs 2045-2049 +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local ltn12 = require("ltn12") +local mime = require("mime.core") +local io = require("io") +local string = require("string") +local _M = mime + +-- encode, decode and wrap algorithm tables +local encodet, decodet, wrapt = {},{},{} + +_M.encodet = encodet +_M.decodet = decodet +_M.wrapt = wrapt + +-- creates a function that chooses a filter by name from a given table +local function choose(table) + return function(name, opt1, opt2) + if base.type(name) ~= "string" then + name, opt1, opt2 = "default", name, opt1 + end + local f = table[name or "nil"] + if not f then + base.error("unknown key (" .. base.tostring(name) .. ")", 3) + else return f(opt1, opt2) end + end +end + +-- define the encoding filters +encodet['base64'] = function() + return ltn12.filter.cycle(_M.b64, "") +end + +encodet['quoted-printable'] = function(mode) + return ltn12.filter.cycle(_M.qp, "", + (mode == "binary") and "=0D=0A" or "\r\n") +end + +-- define the decoding filters +decodet['base64'] = function() + return ltn12.filter.cycle(_M.unb64, "") +end + +decodet['quoted-printable'] = function() + return ltn12.filter.cycle(_M.unqp, "") +end + +local function format(chunk) + if chunk then + if chunk == "" then return "''" + else return string.len(chunk) end + else return "nil" end +end + +-- define the line-wrap filters +wrapt['text'] = function(length) + length = length or 76 + return ltn12.filter.cycle(_M.wrp, length, length) +end +wrapt['base64'] = wrapt['text'] +wrapt['default'] = wrapt['text'] + +wrapt['quoted-printable'] = function() + return ltn12.filter.cycle(_M.qpwrp, 76, 76) +end + +-- function that choose the encoding, decoding or wrap algorithm +_M.encode = choose(encodet) +_M.decode = choose(decodet) +_M.wrap = choose(wrapt) + +-- define the end-of-line normalization filter +function _M.normalize(marker) + return ltn12.filter.cycle(_M.eol, 0, marker) +end + +-- high level stuffing filter +function _M.stuff() + return ltn12.filter.cycle(_M.dot, 2) +end + +return _M \ No newline at end of file diff --git a/mime/core.dll b/mime/core.dll new file mode 100644 index 0000000..c35e86d Binary files /dev/null and b/mime/core.dll differ diff --git a/samp/events.lua b/samp/events.lua new file mode 100644 index 0000000..0764f3a --- /dev/null +++ b/samp/events.lua @@ -0,0 +1,292 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local raknet = require 'samp.raknet' +local events = require 'samp.events.core' +local utils = require 'samp.events.utils' +local handler = require 'samp.events.handlers' + require 'samp.events.extra_types' +local RPC = raknet.RPC +local PACKET = raknet.PACKET +local OUTCOMING_RPCS = events.INTERFACE.OUTCOMING_RPCS +local OUTCOMING_PACKETS = events.INTERFACE.OUTCOMING_PACKETS +local INCOMING_RPCS = events.INTERFACE.INCOMING_RPCS +local INCOMING_PACKETS = events.INTERFACE.INCOMING_PACKETS + +-- Outgoing rpcs +OUTCOMING_RPCS[RPC.ENTERVEHICLE] = {'onSendEnterVehicle', {vehicleId = 'uint16'}, {passenger = 'bool8'}} +OUTCOMING_RPCS[RPC.CLICKPLAYER] = {'onSendClickPlayer', {playerId = 'uint16'}, {source = 'uint8'}} +OUTCOMING_RPCS[RPC.CLIENTJOIN] = {'onSendClientJoin', {version = 'int32'}, {mod = 'uint8'}, {nickname = 'string8'}, {challengeResponse = 'int32'}, {joinAuthKey = 'string8'}, {clientVer = 'string8'}, {challengeResponse2 = 'int32'}} +--OUTCOMING_RPCS[RPC.SELECTOBJECT] = {'onSendSelectObject', {type = 'int32'}, {objectId = 'uint16'}, {model = 'int32'}, {position = 'vector3d'}} +OUTCOMING_RPCS[RPC.SELECTOBJECT] = {'onSendEnterEditObject', {type = 'int32'}, {objectId = 'uint16'}, {model = 'int32'}, {position = 'vector3d'}} +OUTCOMING_RPCS[RPC.SERVERCOMMAND] = {'onSendCommand', {command = 'string32'}} +OUTCOMING_RPCS[RPC.SPAWN] = {'onSendSpawn'} +OUTCOMING_RPCS[RPC.DEATH] = {'onSendDeathNotification', {reason = 'uint8'}, {killerId = 'uint16'}} +OUTCOMING_RPCS[RPC.DIALOGRESPONSE] = {'onSendDialogResponse', {dialogId = 'uint16'}, {button = 'uint8'}, {listboxId = 'uint16'}, {input = 'string8'}} +OUTCOMING_RPCS[RPC.CLICKTEXTDRAW] = {'onSendClickTextDraw', {textdrawId = 'uint16'}} +OUTCOMING_RPCS[RPC.SCMEVENT] = {'onSendVehicleTuningNotification', {vehicleId = 'int32'}, {param1 = 'int32'}, {param2 = 'int32'}, {event = 'int32'}} +OUTCOMING_RPCS[RPC.CHAT] = {'onSendChat', {message = 'string8'}} +OUTCOMING_RPCS[RPC.CLIENTCHECK] = {'onSendClientCheckResponse', {requestType = 'uint8'}, {result1 = 'int32'}, {result2 = 'uint8'}} +OUTCOMING_RPCS[RPC.DAMAGEVEHICLE] = {'onSendVehicleDamaged', {vehicleId = 'uint16'}, {panelDmg = 'int32'}, {doorDmg = 'int32'}, {lights = 'uint8'}, {tires = 'uint8'}} +OUTCOMING_RPCS[RPC.EDITATTACHEDOBJECT] = {'onSendEditAttachedObject', {response = 'int32'}, {index = 'int32'}, {model = 'int32'}, {bone = 'int32'}, {position = 'vector3d'}, {rotation = 'vector3d'}, {scale = 'vector3d'}, {color1 = 'int32'}, {color2 = 'int32'}} +OUTCOMING_RPCS[RPC.EDITOBJECT] = {'onSendEditObject', {playerObject = 'bool'}, {objectId = 'uint16'}, {response = 'int32'}, {position = 'vector3d'}, {rotation = 'vector3d'}} +OUTCOMING_RPCS[RPC.SETINTERIORID] = {'onSendInteriorChangeNotification', {interior = 'uint8'}} +OUTCOMING_RPCS[RPC.MAPMARKER] = {'onSendMapMarker', {position = 'vector3d'}} +OUTCOMING_RPCS[RPC.REQUESTCLASS] = {'onSendRequestClass', {classId = 'int32'}} +OUTCOMING_RPCS[RPC.REQUESTSPAWN] = {'onSendRequestSpawn'} +OUTCOMING_RPCS[RPC.PICKEDUPPICKUP] = {'onSendPickedUpPickup', {pickupId = 'int32'}} +OUTCOMING_RPCS[RPC.MENUSELECT] = {'onSendMenuSelect', {row = 'uint8'}} +OUTCOMING_RPCS[RPC.VEHICLEDESTROYED] = {'onSendVehicleDestroyed', {vehicleId = 'uint16'}} +OUTCOMING_RPCS[RPC.MENUQUIT] = {'onSendQuitMenu'} +OUTCOMING_RPCS[RPC.EXITVEHICLE] = {'onSendExitVehicle', {vehicleId = 'uint16'}} +OUTCOMING_RPCS[RPC.UPDATESCORESPINGSIPS] = {'onSendUpdateScoresAndPings'} +-- playerId = 'uint16', damage = 'float', weapon = 'int32', bodypart ='int32' +OUTCOMING_RPCS[RPC.GIVETAKEDAMAGE] = {{'onSendGiveDamage', 'onSendTakeDamage'}, handler.rpc_send_give_take_damage_reader, handler.rpc_send_give_take_damage_writer} +OUTCOMING_RPCS[RPC.SCRIPTCASH] = {'onSendMoneyIncreaseNotification', {amount = 'int32'}, {increaseType = 'int32'}} +OUTCOMING_RPCS[RPC.NPCJOIN] = {'onSendNPCJoin', {version = 'int32'}, {mod = 'uint8'}, {nickname = 'string8'}, {challengeResponse = 'int32'}} +OUTCOMING_RPCS[RPC.SRVNETSTATS] = {'onSendServerStatisticsRequest'} +OUTCOMING_RPCS[RPC.WEAPONPICKUPDESTROY] = {'onSendPickedUpWeapon', {id = 'uint16'}} +OUTCOMING_RPCS[RPC.CAMTARGETUPDATE] = {'onSendCameraTargetUpdate', {objectId = 'uint16'}, {vehicleId = 'uint16'}, {playerId = 'uint16'}, {actorId = 'uint16'}} +OUTCOMING_RPCS[RPC.GIVEACTORDAMAGE] = {'onSendGiveActorDamage', {_unused = 'bool'}, {actorId = 'uint16'}, {damage = 'float'}, {weapon = 'int32'}, {bodypart ='int32'}} + +-- Incoming rpcs +-- int playerId, string hostName, table settings, table vehicleModels, bool vehicleFriendlyFire +INCOMING_RPCS[RPC.INITGAME] = {'onInitGame', handler.rpc_init_game_reader, handler.rpc_init_game_writer} +INCOMING_RPCS[RPC.SERVERJOIN] = {'onPlayerJoin', {playerId = 'uint16'}, {color = 'int32'}, {isNpc = 'bool8'}, {nickname = 'string8'}} +INCOMING_RPCS[RPC.SERVERQUIT] = {'onPlayerQuit', {playerId = 'uint16'}, {reason = 'uint8'}} +INCOMING_RPCS[RPC.REQUESTCLASS] = {'onRequestClassResponse', {canSpawn = 'bool8'}, {team = 'uint8'}, {skin = 'int32'}, {_unused = 'uint8'}, {positon = 'vector3d'}, {rotation = 'float'}, {weapons = 'Int32Array3'}, {ammo = 'Int32Array3'}} +INCOMING_RPCS[RPC.REQUESTSPAWN] = {'onRequestSpawnResponse', {response = 'bool8'}} +INCOMING_RPCS[RPC.SETPLAYERNAME] = {'onSetPlayerName', {playerId = 'uint16'}, {name = 'string8'}, {success = 'bool8'}} +INCOMING_RPCS[RPC.SETPLAYERPOS] = {'onSetPlayerPos', {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETPLAYERPOSFINDZ] = {'onSetPlayerPosFindZ', {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETPLAYERHEALTH] = {'onSetPlayerHealth', {health = 'float'}} +INCOMING_RPCS[RPC.TOGGLEPLAYERCONTROLLABLE] = {'onTogglePlayerControllable', {controllable = 'bool8'}} +INCOMING_RPCS[RPC.PLAYSOUND] = {'onPlaySound', {soundId = 'int32'}, {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETPLAYERWORLDBOUNDS] = {'onSetWorldBounds', {maxX = 'float'}, {minX = 'float'}, {maxY = 'float'}, {minY = 'float'}} +INCOMING_RPCS[RPC.GIVEPLAYERMONEY] = {'onGivePlayerMoney', {money = 'int32'}} +INCOMING_RPCS[RPC.SETPLAYERFACINGANGLE] = {'onSetPlayerFacingAngle', {angle = 'float'}} +INCOMING_RPCS[RPC.RESETPLAYERMONEY] = {'onResetPlayerMoney'} +INCOMING_RPCS[RPC.RESETPLAYERWEAPONS] = {'onResetPlayerWeapons'} +INCOMING_RPCS[RPC.GIVEPLAYERWEAPON] = {'onGivePlayerWeapon', {weaponId = 'int32'}, {ammo = 'int32'}} +INCOMING_RPCS[RPC.CANCELEDIT] = {'onCancelEdit'} +INCOMING_RPCS[RPC.SETPLAYERTIME] = {'onSetPlayerTime', {hour = 'uint8'}, {minute = 'uint8'}} +INCOMING_RPCS[RPC.TOGGLECLOCK] = {'onSetToggleClock', {state = 'bool8'}} +INCOMING_RPCS[RPC.WORLDPLAYERADD] = {'onPlayerStreamIn', {playerId = 'uint16'}, {team = 'uint8'}, {model = 'int32'}, {position = 'vector3d'}, {rotation = 'float'}, {color = 'int32'}, {fightingStyle = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERSHOPNAME] = {'onSetShopName', {name = 'fixedString32'}} +INCOMING_RPCS[RPC.SETPLAYERSKILLLEVEL] = {'onSetPlayerSkillLevel', {playerId = 'uint16'}, {skill = 'int32'}, {level = 'uint16'}} +INCOMING_RPCS[RPC.SETPLAYERDRUNKLEVEL] = {'onSetPlayerDrunk', {drunkLevel = 'int32'}} +INCOMING_RPCS[RPC.CREATE3DTEXTLABEL] = {'onCreate3DText', {id = 'uint16'}, {color = 'int32'}, {position = 'vector3d'}, {distance = 'float'}, {testLOS = 'bool8'}, {attachedPlayerId = 'uint16'}, {attachedVehicleId = 'uint16'}, {text = 'encodedString4096'}} +INCOMING_RPCS[RPC.DISABLECHECKPOINT] = {'onDisableCheckpoint'} +INCOMING_RPCS[RPC.SETRACECHECKPOINT] = {'onSetRaceCheckpoint', {type = 'uint8'}, {position = 'vector3d'}, {nextPosition = 'vector3d'}, {size = 'float'}} +INCOMING_RPCS[RPC.DISABLERACECHECKPOINT] = {'onDisableRaceCheckpoint'} +INCOMING_RPCS[RPC.GAMEMODERESTART] = {'onGamemodeRestart'} +INCOMING_RPCS[RPC.PLAYAUDIOSTREAM] = {'onPlayAudioStream', {url = 'string8'}, {position = 'vector3d'}, {radius = 'float'}, {usePosition = 'bool8'}} +INCOMING_RPCS[RPC.STOPAUDIOSTREAM] = {'onStopAudioStream'} +INCOMING_RPCS[RPC.REMOVEBUILDINGFORPLAYER] = {'onRemoveBuilding', {modelId = 'int32'}, {position = 'vector3d'}, {radius = 'float'}} +INCOMING_RPCS[RPC.CREATEOBJECT] = {'onCreateObject', handler.rpc_create_object_reader, handler.rpc_create_object_writer} +INCOMING_RPCS[RPC.SETOBJECTPOS] = {'onSetObjectPosition', {objectId = 'uint16'}, {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETOBJECTROT] = {'onSetObjectRotation', {objectId = 'uint16'}, {rotation = 'vector3d'}} +INCOMING_RPCS[RPC.DESTROYOBJECT] = {'onDestroyObject', {objectId = 'uint16'}} +INCOMING_RPCS[RPC.DEATHMESSAGE] = {'onPlayerDeathNotification', {killerId = 'uint16'}, {killedId = 'uint16'}, {reason = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERMAPICON] = {'onSetMapIcon', {iconId = 'uint8'}, {position = 'vector3d'}, {type = 'uint8'}, {color = 'int32'}, {style = 'uint8'}} +INCOMING_RPCS[RPC.REMOVEVEHICLECOMPONENT] = {'onRemoveVehicleComponent', {vehicleId = 'uint16'}, {componentId = 'uint16'}} +INCOMING_RPCS[RPC.DESTROY3DTEXTLABEL] = {'onRemove3DTextLabel', {textLabelId = 'uint16'}} +INCOMING_RPCS[RPC.CHATBUBBLE] = {'onPlayerChatBubble', {playerId = 'uint16'}, {color = 'int32'}, {distance = 'float'}, {duration = 'int32'}, {message = 'string8'}} +INCOMING_RPCS[RPC.UPDATETIME] = {'onUpdateGlobalTimer', {time = 'int32'}} +INCOMING_RPCS[RPC.SHOWDIALOG] = {'onShowDialog', {dialogId = 'uint16'}, {style = 'uint8'}, {title = 'string8'}, {button1 = 'string8'}, {button2 = 'string8'}, {text = 'encodedString4096'}} +INCOMING_RPCS[RPC.DESTROYPICKUP] = {'onDestroyPickup', {id = 'int32'}} +INCOMING_RPCS[RPC.LINKVEHICLETOINTERIOR] = {'onLinkVehicleToInterior', {vehicleId = 'uint16'}, {interiorId = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERARMOUR] = {'onSetPlayerArmour', {armour = 'float'}} +INCOMING_RPCS[RPC.SETPLAYERARMEDWEAPON] = {'onSetPlayerArmedWeapon', {weaponId = 'int32'}} +INCOMING_RPCS[RPC.SETSPAWNINFO] = {'onSetSpawnInfo', {team = 'uint8'}, {skin = 'int32'}, {_unused = 'uint8'}, {position = 'vector3d'}, {rotation = 'float'}, {weapons = 'Int32Array3'}, {ammo = 'Int32Array3'}} +INCOMING_RPCS[RPC.SETPLAYERTEAM] = {'onSetPlayerTeam', {playerId = 'uint16'}, {teamId = 'uint8'}} +INCOMING_RPCS[RPC.PUTPLAYERINVEHICLE] = {'onPutPlayerInVehicle', {vehicleId = 'uint16'}, {seatId = 'uint8'}} +INCOMING_RPCS[RPC.REMOVEPLAYERFROMVEHICLE] = {'onRemovePlayerFromVehicle'} +INCOMING_RPCS[RPC.SETPLAYERCOLOR] = {'onSetPlayerColor', {playerId = 'uint16'}, {color = 'int32'}} +INCOMING_RPCS[RPC.DISPLAYGAMETEXT] = {'onDisplayGameText', {style = 'int32'}, {time = 'int32'}, {text = 'string32'}} +INCOMING_RPCS[RPC.FORCECLASSSELECTION] = {'onForceClassSelection'} +INCOMING_RPCS[RPC.ATTACHOBJECTTOPLAYER] = {'onAttachObjectToPlayer', {objectId = 'uint16'}, {playerId = 'uint16'}, {offsets = 'vector3d'}, {rotation = 'vector3d'}} +-- menuId = 'uint8', menuTitle = 'fixedString32', x = 'float', y = 'float', twoColumns = 'bool32', columns = 'table', rows = 'table', menu = 'bool32' +INCOMING_RPCS[RPC.INITMENU] = {'onInitMenu', handler.rpc_init_menu_reader, handler.rpc_init_menu_writer} +INCOMING_RPCS[RPC.SHOWMENU] = {'onShowMenu', {menuId = 'uint8'}} +INCOMING_RPCS[RPC.HIDEMENU] = {'onHideMenu', {menuId = 'uint8'}} +INCOMING_RPCS[RPC.CREATEEXPLOSION] = {'onCreateExplosion', {position = 'vector3d'}, {style = 'int32'}, {radius = 'float'}} +INCOMING_RPCS[RPC.SHOWPLAYERNAMETAGFORPLAYER] = {'onShowPlayerNameTag', {playerId = 'uint16'}, {show = 'bool8'}} +INCOMING_RPCS[RPC.ATTACHCAMERATOOBJECT] = {'onAttachCameraToObject', {objectId = 'uint16'}} +INCOMING_RPCS[RPC.INTERPOLATECAMERA] = {'onInterpolateCamera', {setPos = 'bool'}, {fromPos = 'vector3d'}, {destPos = 'vector3d'}, {time = 'int32'}, {mode = 'uint8'}} +INCOMING_RPCS[RPC.GANGZONESTOPFLASH] = {'onGangZoneStopFlash', {zoneId = 'uint16'}} +INCOMING_RPCS[RPC.APPLYANIMATION] = {'onApplyPlayerAnimation', {playerId = 'uint16'}, {animLib = 'string8'}, {animName = 'string8'}, {frameDelta = 'float'}, {loop = 'bool'}, {lockX = 'bool'}, {lockY = 'bool'}, {freeze = 'bool'}, {time = 'int32'}} +INCOMING_RPCS[RPC.CLEARANIMATIONS] = {'onClearPlayerAnimation', {playerId = 'uint16'}} +INCOMING_RPCS[RPC.SETPLAYERSPECIALACTION] = {'onSetPlayerSpecialAction', {actionId = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERFIGHTINGSTYLE] = {'onSetPlayerFightingStyle', {playerId = 'uint16'}, {styleId = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERVELOCITY] = {'onSetPlayerVelocity', {velocity = 'vector3d'}} +INCOMING_RPCS[RPC.SETVEHICLEVELOCITY] = {'onSetVehicleVelocity', {turn = 'bool8'}, {velocity = 'vector3d'}} +INCOMING_RPCS[RPC.CLIENTMESSAGE] = {'onServerMessage', {color = 'int32'}, {text = 'string32'}} +INCOMING_RPCS[RPC.SETWORLDTIME] = {'onSetWorldTime', {hour = 'uint8'}} +INCOMING_RPCS[RPC.CREATEPICKUP] = {'onCreatePickup', {id = 'int32'}, {model = 'int32'}, {pickupType = 'int32'}, {position = 'vector3d'}} +INCOMING_RPCS[RPC.MOVEOBJECT] = {'onMoveObject', {objectId = 'uint16'}, {fromPos = 'vector3d'}, {destPos = 'vector3d'}, {speed = 'float'}, {rotation = 'vector3d'}} +INCOMING_RPCS[RPC.ENABLESTUNTBONUSFORPLAYER] = {'onEnableStuntBonus', {state = 'bool'}} +INCOMING_RPCS[RPC.TEXTDRAWSETSTRING] = {'onTextDrawSetString', {id = 'uint16'}, {text = 'string16'}} +INCOMING_RPCS[RPC.SETCHECKPOINT] = {'onSetCheckpoint', {position = 'vector3d'}, {radius = 'float'}} +INCOMING_RPCS[RPC.GANGZONECREATE] = {'onCreateGangZone', {zoneId = 'uint16'}, {squareStart = 'vector2d'}, {squareEnd = 'vector2d'}, {color = 'int32'}} +INCOMING_RPCS[RPC.PLAYCRIMEREPORT] = {'onPlayCrimeReport', {suspectId = 'uint16'}, {inVehicle = 'bool32'}, {vehicleModel = 'int32'}, {vehicleColor = 'int32'}, {crime = 'int32'}, {coordinates = 'vector3d'}} +INCOMING_RPCS[RPC.GANGZONEDESTROY] = {'onGangZoneDestroy', {zoneId = 'uint16'}} +INCOMING_RPCS[RPC.GANGZONEFLASH] = {'onGangZoneFlash', {zoneId = 'uint16'}, {color = 'int32'}} +INCOMING_RPCS[RPC.STOPOBJECT] = {'onStopObject', {objectId = 'uint16'}} +INCOMING_RPCS[RPC.SETNUMBERPLATE] = {'onSetVehicleNumberPlate', {vehicleId = 'uint16'}, {text = 'string8'}} +INCOMING_RPCS[RPC.TOGGLEPLAYERSPECTATING] = {'onTogglePlayerSpectating', {state = 'bool32'}} +INCOMING_RPCS[RPC.PLAYERSPECTATEPLAYER] = {'onSpectatePlayer', {playerId = 'uint16'}, {camType = 'uint8'}} +INCOMING_RPCS[RPC.PLAYERSPECTATEVEHICLE] = {'onSpectateVehicle', {vehicleId = 'uint16'}, {camType = 'uint8'}} +INCOMING_RPCS[RPC.SHOWTEXTDRAW] = {'onShowTextDraw', + {textdrawId = 'uint16'}, + {textdraw = { + {flags = 'uint8'}, + {letterWidth = 'float'}, + {letterHeight = 'float'}, + {letterColor = 'int32'}, + {lineWidth = 'float'}, + {lineHeight = 'float'}, + {boxColor = 'int32'}, + {shadow = 'uint8'}, + {outline = 'uint8'}, + {backgroundColor = 'int32'}, + {style = 'uint8'}, + {selectable = 'uint8'}, + {position = 'vector2d'}, + {modelId = 'uint16'}, + {rotation = 'vector3d'}, + {zoom = 'float'}, + {color = 'int32'}, + {text = 'string16'} + }} +} +INCOMING_RPCS[RPC.SETPLAYERWANTEDLEVEL] = {'onSetPlayerWantedLevel', {wantedLevel = 'uint8'}} +INCOMING_RPCS[RPC.TEXTDRAWHIDEFORPLAYER] = {'onTextDrawHide', {textDrawId = 'uint16'}} +INCOMING_RPCS[RPC.REMOVEPLAYERMAPICON] = {'onRemoveMapIcon', {iconId = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERAMMO] = {'onSetWeaponAmmo', {weaponId = 'uint8'}, {ammo = 'uint16'}} +INCOMING_RPCS[RPC.SETGRAVITY] = {'onSetGravity', {gravity = 'float'}} +INCOMING_RPCS[RPC.SETVEHICLEHEALTH] = {'onSetVehicleHealth', {vehicleId = 'uint16'}, {health = 'float'}} +INCOMING_RPCS[RPC.ATTACHTRAILERTOVEHICLE] = {'onAttachTrailerToVehicle', {trailerId = 'uint16'}, {vehicleId = 'uint16'}} +INCOMING_RPCS[RPC.DETACHTRAILERFROMVEHICLE] = {'onDetachTrailerFromVehicle', {vehicleId = 'uint16'}} +INCOMING_RPCS[RPC.SETWEATHER] = {'onSetWeather', {weatherId = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERSKIN] = {'onSetPlayerSkin', {playerId = 'int32'}, {skinId = 'int32'}} +INCOMING_RPCS[RPC.SETPLAYERINTERIOR] = {'onSetInterior', {interior = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERCAMERAPOS] = {'onSetCameraPosition', {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETPLAYERCAMERALOOKAT] = {'onSetCameraLookAt', {lookAtPosition = 'vector3d'}, {cutType = 'uint8'}} +INCOMING_RPCS[RPC.SETVEHICLEPOS] = {'onSetVehiclePosition', {vehicleId = 'uint16'}, {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETVEHICLEZANGLE] = {'onSetVehicleAngle', {vehicleId = 'uint16'}, {angle = 'float'}} +INCOMING_RPCS[RPC.SETVEHICLEPARAMSFORPLAYER] = {'onSetVehicleParams', {vehicleId = 'uint16'}, {objective = 'bool8'}, {doorsLocked = 'bool8'}} +INCOMING_RPCS[RPC.SETCAMERABEHINDPLAYER] = {'onSetCameraBehind'} +INCOMING_RPCS[RPC.CHAT] = {'onChatMessage', {playerId = 'uint16'}, {text = 'string8'}} +INCOMING_RPCS[RPC.CONNECTIONREJECTED] = {'onConnectionRejected', {reason = 'uint8'}} +INCOMING_RPCS[RPC.WORLDPLAYERREMOVE] = {'onPlayerStreamOut', {playerId = 'uint16'}} +INCOMING_RPCS[RPC.WORLDVEHICLEADD] = {'onVehicleStreamIn', handler.rpc_vehicle_stream_in_reader, handler.rpc_vehicle_stream_in_writer} +INCOMING_RPCS[RPC.WORLDVEHICLEREMOVE] = {'onVehicleStreamOut', {vehicleId = 'uint16'}} +INCOMING_RPCS[RPC.WORLDPLAYERDEATH] = {'onPlayerDeath', {playerId = 'uint16'}} +INCOMING_RPCS[RPC.ENTERVEHICLE] = {'onPlayerEnterVehicle', {playerId = 'uint16'}, {vehicleId = 'uint16'}, {passenger = 'bool8'}} +INCOMING_RPCS[RPC.UPDATESCORESPINGSIPS] = {'onUpdateScoresAndPings', handler.rpc_update_scores_and_pings_reader, handler.rpc_update_scores_and_pings_writer} +INCOMING_RPCS[RPC.SETOBJECTMATERIAL] = {{'onSetObjectMaterial', 'onSetObjectMaterialText'}, handler.rpc_set_object_material_reader, handler.rpc_set_object_material_writer} +INCOMING_RPCS[RPC.CREATEACTOR] = {'onCreateActor', {actorId = 'uint16'}, {skinId = 'int32'}, {position = 'vector3d'}, {rotation = 'float'}, {health = 'float'}} +INCOMING_RPCS[RPC.CLICKTEXTDRAW] = {'onToggleSelectTextDraw', {state = 'bool'}, {hovercolor = 'int32'}} +INCOMING_RPCS[RPC.SETVEHICLEPARAMSEX] = {'onSetVehicleParamsEx', + {vehicleId = 'uint16'}, + {params = { + {engine = 'uint8'}, + {lights = 'uint8'}, + {alarm = 'uint8'}, + {doors = 'uint8'}, + {bonnet = 'uint8'}, + {boot = 'uint8'}, + {objective = 'uint8'}, + {unknown = 'uint8'} + }}, + {doors = { + {driver = 'uint8'}, + {passenger = 'uint8'}, + {backleft = 'uint8'}, + {backright = 'uint8'} + }}, + {windows = { + {driver = 'uint8'}, + {passenger = 'uint8'}, + {backleft = 'uint8'}, + {backright = 'uint8'} + }} +} +INCOMING_RPCS[RPC.SETPLAYERATTACHEDOBJECT] = {'onSetPlayerAttachedObject', + {playerId = 'uint16'}, + {index = 'int32'}, + {create = 'bool'}, + {object = { + {modelId = 'int32'}, + {bone = 'int32'}, + {offset = 'vector3d'}, + {rotation = 'vector3d'}, + {scale = 'vector3d'}, + {color1 = 'int32'}, + {color2 = 'int32'}} + } +} +INCOMING_RPCS[RPC.CLIENTCHECK] = {'onClientCheck', {requestType = 'uint8'}, {subject = 'int32'}, {offset = 'uint16'}, {length = 'uint16'}} +INCOMING_RPCS[RPC.DESTROYACTOR] = {'onDestroyActor', {actorId = 'uint16'}} +INCOMING_RPCS[RPC.DESTROYWEAPONPICKUP] = {'onDestroyWeaponPickup', {id = 'uint8'}} +INCOMING_RPCS[RPC.EDITATTACHEDOBJECT] = {'onEditAttachedObject', {index = 'int32'}} +INCOMING_RPCS[RPC.TOGGLECAMERATARGET] = {'onToggleCameraTargetNotifying', {enable = 'bool'}} +INCOMING_RPCS[RPC.SELECTOBJECT] = {'onEnterSelectObject'} +INCOMING_RPCS[RPC.EXITVEHICLE] = {'onPlayerExitVehicle', {playerId = 'uint16'}, {vehicleId = 'uint16'}} +INCOMING_RPCS[RPC.SCMEVENT] = {'onVehicleTuningNotification', {playerId = 'uint16'}, {event = 'int32'}, {vehicleId = 'int32'}, {param1 = 'int32'}, {param2 = 'int32'}} +INCOMING_RPCS[RPC.SRVNETSTATS] = {'onServerStatisticsResponse'} --, {data = 'RakNetStatisticsStruct'}} +INCOMING_RPCS[RPC.EDITOBJECT] = {'onEnterEditObject', {playerObject = 'bool'}, {objectId = 'uint16'}} +INCOMING_RPCS[RPC.DAMAGEVEHICLE] = {'onVehicleDamageStatusUpdate', {vehicleId = 'uint16'}, {panelDmg = 'int32'}, {doorDmg = 'int32'}, {lights = 'uint8'}, {tires = 'uint8'}} +INCOMING_RPCS[RPC.DISABLEVEHICLECOLLISIONS] = {'onDisableVehicleCollisions', {disable = 'bool'}} +INCOMING_RPCS[RPC.TOGGLEWIDESCREEN] = {'onToggleWidescreen', {enable = 'bool8'}} +INCOMING_RPCS[RPC.SETVEHICLETIRES] = {'onSetVehicleTires', {vehicleId = 'uint16'}, {tires = 'uint8'}} +INCOMING_RPCS[RPC.SETPLAYERDRUNKVISUALS] = {'onSetPlayerDrunkVisuals', {level = 'int32'}} +INCOMING_RPCS[RPC.SETPLAYERDRUNKHANDLING] = {'onSetPlayerDrunkHandling', {level = 'int32'}} +INCOMING_RPCS[RPC.APPLYACTORANIMATION] = {'onApplyActorAnimation', {actorId = 'uint16'}, {animLib = 'string8'}, {animName = 'string8'}, {frameDelta = 'float'}, {loop = 'bool'}, {lockX = 'bool'}, {lockY = 'bool'}, {freeze = 'bool'}, {time = 'int32'}} +INCOMING_RPCS[RPC.CLEARACTORANIMATION] = {'onClearActorAnimation', {actorId = 'uint16'}} +INCOMING_RPCS[RPC.SETACTORROTATION] = {'onSetActorFacingAngle', {actorId = 'uint16'}, {angle = 'float'}} +INCOMING_RPCS[RPC.SETACTORPOSITION] = {'onSetActorPos', {actorId = 'uint16'}, {position = 'vector3d'}} +INCOMING_RPCS[RPC.SETACTORHEALTH] = {'onSetActorHealth', {actorId = 'uint16'}, {health = 'float'}} +INCOMING_RPCS[RPC.SETPLAYEROBJECTNOCAMCOL] = {'onSetPlayerObjectNoCameraCol', {objectId = 'uint16'}} +INCOMING_RPCS[125] = {'_dummy125'} +INCOMING_RPCS[64] = {'_dummy64', {'uint16'}} +INCOMING_RPCS[48] = {'_unused48', {'int32'}} + + +-- Outgoing packets +OUTCOMING_PACKETS[PACKET.RCON_COMMAND] = {'onSendRconCommand', {command = 'string32'}} +OUTCOMING_PACKETS[PACKET.STATS_UPDATE] = {'onSendStatsUpdate', {money = 'int32'}, {drunkLevel = 'int32'}} +local function empty_writer() end +OUTCOMING_PACKETS[PACKET.PLAYER_SYNC] = {'onSendPlayerSync', function(bs) return utils.process_outcoming_sync_data(bs, 'PlayerSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.VEHICLE_SYNC] = {'onSendVehicleSync', function(bs) return utils.process_outcoming_sync_data(bs, 'VehicleSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.PASSENGER_SYNC] = {'onSendPassengerSync', function(bs) return utils.process_outcoming_sync_data(bs, 'PassengerSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.AIM_SYNC] = {'onSendAimSync', function(bs) return utils.process_outcoming_sync_data(bs, 'AimSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.UNOCCUPIED_SYNC] = {'onSendUnoccupiedSync', function(bs) return utils.process_outcoming_sync_data(bs, 'UnoccupiedSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.TRAILER_SYNC] = {'onSendTrailerSync', function(bs) return utils.process_outcoming_sync_data(bs, 'TrailerSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.BULLET_SYNC] = {'onSendBulletSync', function(bs) return utils.process_outcoming_sync_data(bs, 'BulletSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.SPECTATOR_SYNC] = {'onSendSpectatorSync', function(bs) return utils.process_outcoming_sync_data(bs, 'SpectatorSyncData') end, empty_writer} +OUTCOMING_PACKETS[PACKET.WEAPONS_UPDATE] = {'onSendWeaponsUpdate', handler.packet_weapons_update_reader, handler.packet_weapons_update_writer} +OUTCOMING_PACKETS[PACKET.AUTHENTICATION] = {'onSendAuthenticationResponse', {response = 'string8'}} + +-- Incoming packets +INCOMING_PACKETS[PACKET.PLAYER_SYNC] = {'onPlayerSync', handler.packet_player_sync_reader, handler.packet_player_sync_writer} +INCOMING_PACKETS[PACKET.VEHICLE_SYNC] = {'onVehicleSync', handler.packet_vehicle_sync_reader, handler.packet_vehicle_sync_writer} +INCOMING_PACKETS[PACKET.MARKERS_SYNC] = {'onMarkersSync', handler.packet_markers_sync_reader, handler.packet_markers_sync_writer} +INCOMING_PACKETS[PACKET.AIM_SYNC] = {'onAimSync', {playerId = 'uint16'}, {data = 'AimSyncData'}} +INCOMING_PACKETS[PACKET.BULLET_SYNC] = {'onBulletSync', {playerId = 'uint16'}, {data = 'BulletSyncData'}} +INCOMING_PACKETS[PACKET.UNOCCUPIED_SYNC] = {'onUnoccupiedSync', {playerId = 'uint16'}, {data = 'UnoccupiedSyncData'}} +INCOMING_PACKETS[PACKET.TRAILER_SYNC] = {'onTrailerSync', {playerId = 'uint16'}, {data = 'TrailerSyncData'}} +INCOMING_PACKETS[PACKET.PASSENGER_SYNC] = {'onPassengerSync', {playerId = 'uint16'}, {data = 'PassengerSyncData'}} +INCOMING_PACKETS[PACKET.AUTHENTICATION] = {'onAuthenticationRequest', {key = 'string8'}} +INCOMING_PACKETS[PACKET.CONNECTION_REQUEST_ACCEPTED] = {'onConnectionRequestAccepted', {ip = 'int32'}, {port = 'uint16'}, {playerId = 'uint16'}, {challenge = 'int32'}} +INCOMING_PACKETS[PACKET.CONNECTION_LOST] = {'onConnectionLost'} +INCOMING_PACKETS[PACKET.CONNECTION_BANNED] = {'onConnectionBanned'} +INCOMING_PACKETS[PACKET.CONNECTION_ATTEMPT_FAILED] = {'onConnectionAttemptFailed'} +INCOMING_PACKETS[PACKET.NO_FREE_INCOMING_CONNECTIONS] = {'onConnectionNoFreeSlot'} +INCOMING_PACKETS[PACKET.INVALID_PASSWORD] = {'onConnectionPasswordInvalid'} +INCOMING_PACKETS[PACKET.DISCONNECTION_NOTIFICATION] = {'onConnectionClosed'} + +return events diff --git a/samp/events/bitstream_io.lua b/samp/events/bitstream_io.lua new file mode 100644 index 0000000..b215392 --- /dev/null +++ b/samp/events/bitstream_io.lua @@ -0,0 +1,267 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local mod = {} +local vector3d = require 'vector3d' +local ffi = require 'ffi' + +local function bitstream_read_fixed_string(bs, size) + local buf = ffi.new('uint8_t[?]', size + 1) + raknetBitStreamReadBuffer(bs, tonumber(ffi.cast('intptr_t', buf)), size) + buf[size] = 0 + -- Length is not specified to throw off trailing zeros. + return ffi.string(buf) +end + +local function bitstream_write_fixed_string(bs, str, size) + local buf = ffi.new('uint8_t[?]', size, string.sub(str, 1, size)) + raknetBitStreamWriteBuffer(bs, tonumber(ffi.cast('intptr_t', buf)), size) +end + +mod.bool = { + read = function(bs) return raknetBitStreamReadBool(bs) end, + write = function(bs, value) return raknetBitStreamWriteBool(bs, value) end +} + +mod.uint8 = { + read = function(bs) return raknetBitStreamReadInt8(bs) end, + write = function(bs, value) return raknetBitStreamWriteInt8(bs, value) end +} + +mod.uint16 = { + read = function(bs) return raknetBitStreamReadInt16(bs) end, + write = function(bs, value) return raknetBitStreamWriteInt16(bs, value) end +} + +mod.uint32 = { + read = function(bs) + local v = raknetBitStreamReadInt32(bs) + return v < 0 and 0x100000000 + v or v + end, + write = function(bs, value) + return raknetBitStreamWriteInt32(bs, value) + end +} + +mod.int8 = { + read = function(bs) + local v = raknetBitStreamReadInt8(bs) + return v >= 0x80 and v - 0x100 or v + end, + write = function(bs, value) + return raknetBitStreamWriteInt8(bs, value) + end +} + +mod.int16 = { + read = function(bs) + local v = raknetBitStreamReadInt16(bs) + return v >= 0x8000 and v - 0x10000 or v + end, + write = function(bs, value) + return raknetBitStreamWriteInt16(bs, value) + end +} + +mod.int32 = { + read = function(bs) return raknetBitStreamReadInt32(bs) end, + write = function(bs, value) return raknetBitStreamWriteInt32(bs, value) end +} + +mod.float = { + read = function(bs) return raknetBitStreamReadFloat(bs) end, + write = function(bs, value) return raknetBitStreamWriteFloat(bs, value) end +} + +mod.string8 = { + read = function(bs) + local len = raknetBitStreamReadInt8(bs) + if len <= 0 then return '' end + return raknetBitStreamReadString(bs, len) + end, + write = function(bs, value) + raknetBitStreamWriteInt8(bs, #value) + raknetBitStreamWriteString(bs, value) + end +} + +mod.string16 = { + read = function(bs) + local len = raknetBitStreamReadInt16(bs) + if len <= 0 then return '' end + return raknetBitStreamReadString(bs, len) + end, + write = function(bs, value) + raknetBitStreamWriteInt16(bs, #value) + raknetBitStreamWriteString(bs, value) + end +} + +mod.string32 = { + read = function(bs) + local len = raknetBitStreamReadInt32(bs) + if len <= 0 then return '' end + return raknetBitStreamReadString(bs, len) + end, + write = function(bs, value) + raknetBitStreamWriteInt32(bs, #value) + raknetBitStreamWriteString(bs, value) + end +} + +mod.bool8 = { + read = function(bs) + return raknetBitStreamReadInt8(bs) ~= 0 + end, + write = function(bs, value) + raknetBitStreamWriteInt8(bs, value == true and 1 or 0) + end +} + +mod.bool32 = { + read = function(bs) + return raknetBitStreamReadInt32(bs) ~= 0 + end, + write = function(bs, value) + raknetBitStreamWriteInt32(bs, value == true and 1 or 0) + end +} + +mod.int1 = { + read = function(bs) + if raknetBitStreamReadBool(bs) == true then return 1 else return 0 end + end, + write = function(bs, value) + raknetBitStreamWriteBool(bs, value ~= 0 and true or false) + end +} + +mod.fixedString32 = { + read = function(bs) + return bitstream_read_fixed_string(bs, 32) + end, + write = function(bs, value) + bitstream_write_fixed_string(bs, value, 32) + end +} + +mod.string256 = mod.fixedString32 + +mod.encodedString2048 = { + read = function(bs) return raknetBitStreamDecodeString(bs, 2048) end, + write = function(bs, value) raknetBitStreamEncodeString(bs, value) end +} + +mod.encodedString4096 = { + read = function(bs) return raknetBitStreamDecodeString(bs, 4096) end, + write = function(bs, value) raknetBitStreamEncodeString(bs, value) end +} + +mod.compressedFloat = { + read = function(bs) + return raknetBitStreamReadInt16(bs) / 32767.5 - 1 + end, + write = function(bs, value) + if value < -1 then + value = -1 + elseif value > 1 then + value = 1 + end + raknetBitStreamWriteInt16(bs, (value + 1) * 32767.5) + end +} + +mod.compressedVector = { + read = function(bs) + local magnitude = raknetBitStreamReadFloat(bs) + if magnitude ~= 0 then + local readCf = mod.compressedFloat.read + return vector3d(readCf(bs) * magnitude, readCf(bs) * magnitude, readCf(bs) * magnitude) + else + return vector3d(0, 0, 0) + end + end, + write = function(bs, data) + local x, y, z = data.x, data.y, data.z + local magnitude = math.sqrt(x * x + y * y + z * z) + raknetBitStreamWriteFloat(bs, magnitude) + if magnitude > 0 then + local writeCf = mod.compressedFloat.write + writeCf(bs, x / magnitude) + writeCf(bs, y / magnitude) + writeCf(bs, z / magnitude) + end + end +} + +mod.normQuat = { + read = function(bs) + local readBool, readShort = raknetBitStreamReadBool, raknetBitStreamReadInt16 + local cwNeg, cxNeg, cyNeg, czNeg = readBool(bs), readBool(bs), readBool(bs), readBool(bs) + local cx, cy, cz = readShort(bs), readShort(bs), readShort(bs) + local x = cx / 65535 + local y = cy / 65535 + local z = cz / 65535 + if cxNeg then x = -x end + if cyNeg then y = -y end + if czNeg then z = -z end + local diff = 1 - x * x - y * y - z * z + if diff < 0 then diff = 0 end + local w = math.sqrt(diff) + if cwNeg then w = -w end + return {w, x, y, z} + end, + write = function(bs, value) + local w, x, y, z = value[1], value[2], value[3], value[4] + raknetBitStreamWriteBool(bs, w < 0) + raknetBitStreamWriteBool(bs, x < 0) + raknetBitStreamWriteBool(bs, y < 0) + raknetBitStreamWriteBool(bs, z < 0) + raknetBitStreamWriteInt16(bs, math.abs(x) * 65535) + raknetBitStreamWriteInt16(bs, math.abs(y) * 65535) + raknetBitStreamWriteInt16(bs, math.abs(z) * 65535) + -- w is calculated on the target + end +} + +mod.vector3d = { + read = function(bs) + local x, y, z = + raknetBitStreamReadFloat(bs), + raknetBitStreamReadFloat(bs), + raknetBitStreamReadFloat(bs) + return vector3d(x, y, z) + end, + write = function(bs, value) + raknetBitStreamWriteFloat(bs, value.x) + raknetBitStreamWriteFloat(bs, value.y) + raknetBitStreamWriteFloat(bs, value.z) + end +} + +mod.vector2d = { + read = function(bs) + local x = raknetBitStreamReadFloat(bs) + local y = raknetBitStreamReadFloat(bs) + return {x = x, y = y} + end, + write = function(bs, value) + raknetBitStreamWriteFloat(bs, value.x) + raknetBitStreamWriteFloat(bs, value.y) + end +} + +local function bitstream_io_interface(field) + return setmetatable({}, { + __index = function(t, index) + return mod[index][field] + end + }) +end + +mod.bs_read = bitstream_io_interface('read') +mod.bs_write = bitstream_io_interface('write') + +return mod diff --git a/samp/events/core.lua b/samp/events/core.lua new file mode 100644 index 0000000..0d30174 --- /dev/null +++ b/samp/events/core.lua @@ -0,0 +1,141 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local MODULE = { + MODULEINFO = { + name = 'samp.events', + version = 4 + }, + INTERFACE = { + OUTCOMING_RPCS = {}, + OUTCOMING_PACKETS = {}, + INCOMING_RPCS = {}, + INCOMING_PACKETS = {} + }, + EXPORTS = {} +} + +-- check dependencies +assert(isSampLoaded(), 'SA-MP is not loaded') +assert(isSampfuncsLoaded(), 'samp.events requires SAMPFUNCS') +assert(getMoonloaderVersion() >= 20, 'samp.events requires MoonLoader v.020 or greater') + +local BitStreamIO = require 'samp.events.bitstream_io' +MODULE.INTERFACE.BitStreamIO = BitStreamIO + + +local function read_data(bs, dataType) + if type(dataType) ~= 'table' then + return BitStreamIO[dataType].read(bs) + else -- process nested structures + local values = {} + for _, it in ipairs(dataType) do + local name, t = next(it) + values[name] = read_data(bs, t) + end + return values + end +end + +local function write_data(bs, dataType, value) + if type(dataType) ~= 'table' then + BitStreamIO[dataType].write(bs, value) + else -- process nested structures + for _, it in ipairs(dataType) do + local name, t = next(it) + write_data(bs, t, value[name]) + end + end +end + +local function process_event(bs, callback, struct, ignorebits) + local args = {} + if bs ~= 0 then + if ignorebits then + raknetBitStreamIgnoreBits(bs, ignorebits) + end + if type(struct[2]) == 'function' then + local r1, r2 = struct[2](bs) -- call custom reading function + if type(callback) == 'table' and type(r1) == 'string' then + callback = callback[r1] + if callback then + args = r2 + else + return + end + else + args = r1 + end + else + -- skip event name + for i = 2, #struct do + local _, t = next(struct[i]) -- type + table.insert(args, read_data(bs, t)) + end + end + end + local result = callback(unpack(args)) + if result == false then + return false -- consume packet + end + if bs ~= 0 and type(result) == 'table' then + raknetBitStreamSetWriteOffset(bs, ignorebits or 0) + if type(struct[3]) == 'function' then + struct[3](bs, result) -- call custom writing function + else + assert(#struct - 1 == #result) + for i = 2, #struct do + local _, t = next(struct[i]) -- type + write_data(bs, t, result[i - 1]) + end + end + end +end + +local function process_packet(id, bs, event_table, ignorebits) + local entry = event_table[id] + if entry ~= nil then + local key = entry[1] + local callback = nil + if type(key) == 'table' then + for i, name in ipairs(key) do + if type(MODULE[name]) == 'function' then + if not callback then callback = {} end + callback[name] = MODULE[name] + end + end + elseif type(MODULE[key]) == 'function' then + callback = MODULE[key] + end + if callback then + return process_event(bs, callback, entry, ignorebits) + end + end +end + + +local interface = MODULE.INTERFACE +local function samp_on_send_rpc(id, bitStream, priority, reliability, orderingChannel, shiftTs) + return process_packet(id, bitStream, interface.OUTCOMING_RPCS) +end + +local function samp_on_send_packet(id, bitStream, priority, reliability, orderingChannel) + return process_packet(id, bitStream, interface.OUTCOMING_PACKETS, 8) +end + +local function samp_on_receive_rpc(id, bitStream) + return process_packet(id, bitStream, interface.INCOMING_RPCS) +end + +local function samp_on_receive_packet(id, bitStream) + return process_packet(id, bitStream, interface.INCOMING_PACKETS, 8) +end + +addEventHandler('onSendRpc', samp_on_send_rpc) +addEventHandler('onSendPacket', samp_on_send_packet) +addEventHandler('onReceiveRpc', samp_on_receive_rpc) +addEventHandler('onReceivePacket', samp_on_receive_packet) + +return MODULE diff --git a/samp/events/extra_types.lua b/samp/events/extra_types.lua new file mode 100644 index 0000000..9525305 --- /dev/null +++ b/samp/events/extra_types.lua @@ -0,0 +1,45 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local BitStreamIO = require 'samp.events.bitstream_io' +local utils = require 'samp.events.utils' + +BitStreamIO.Int32Array3 = { + read = function(bs) + local arr = {} + for i = 1, 3 do arr[i] = raknetBitStreamReadInt32(bs) end + return arr + end, + write = function(bs, value) + for i = 1, 3 do raknetBitStreamWriteInt32(bs, value[i]) end + end +} + +BitStreamIO.AimSyncData = { + read = function(bs) return utils.read_sync_data(bs, 'AimSyncData') end, + write = function(bs, value) utils.write_sync_data(bs, 'AimSyncData', value) end +} + +BitStreamIO.UnoccupiedSyncData = { + read = function(bs) return utils.read_sync_data(bs, 'UnoccupiedSyncData') end, + write = function(bs, value) utils.write_sync_data(bs, 'UnoccupiedSyncData', value) end +} + +BitStreamIO.PassengerSyncData = { + read = function(bs) return utils.read_sync_data(bs, 'PassengerSyncData') end, + write = function(bs, value) utils.write_sync_data(bs, 'PassengerSyncData', value) end +} + +BitStreamIO.BulletSyncData = { + read = function(bs) return utils.read_sync_data(bs, 'BulletSyncData') end, + write = function(bs, value) utils.write_sync_data(bs, 'BulletSyncData', value) end +} + +BitStreamIO.TrailerSyncData = { + read = function(bs) return utils.read_sync_data(bs, 'TrailerSyncData') end, + write = function(bs, value) utils.write_sync_data(bs, 'TrailerSyncData', value) end +} + +return BitStreamIO diff --git a/samp/events/handlers.lua b/samp/events/handlers.lua new file mode 100644 index 0000000..1029acf --- /dev/null +++ b/samp/events/handlers.lua @@ -0,0 +1,526 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local bs_io = require 'samp.events.bitstream_io' +local utils = require 'samp.events.utils' +local bsread, bswrite = bs_io.bs_read, bs_io.bs_write +local handler = {} + +--- onSendGiveDamage, onSendTakeDamage +function handler.rpc_send_give_take_damage_reader(bs) + local take = bsread.bool(bs) -- 'true' is take damage + local data = { + bsread.uint16(bs), -- playerId + bsread.float(bs), -- damage + bsread.int32(bs), -- weapon + bsread.int32(bs), -- bodypart + take, + } + return (take and 'onSendTakeDamage' or 'onSendGiveDamage'), data +end + +function handler.rpc_send_give_take_damage_writer(bs, data) + bswrite.bool(bs, data[5]) -- give or take + bswrite.uint16(bs, data[1]) -- playerId + bswrite.float(bs, data[2]) -- damage + bswrite.int32(bs, data[3]) -- weapon + bswrite.int32(bs, data[4]) -- bodypart +end + +--- onInitGame +function handler.rpc_init_game_reader(bs) + local settings = {} + settings.zoneNames = bsread.bool(bs) + settings.useCJWalk = bsread.bool(bs) + settings.allowWeapons = bsread.bool(bs) + settings.limitGlobalChatRadius = bsread.bool(bs) + settings.globalChatRadius = bsread.float(bs) + settings.stuntBonus = bsread.bool(bs) + settings.nametagDrawDist = bsread.float(bs) + settings.disableEnterExits = bsread.bool(bs) + settings.nametagLOS = bsread.bool(bs) + settings.tirePopping = bsread.bool(bs) + settings.classesAvailable = bsread.int32(bs) + local playerId = bsread.uint16(bs) + settings.showPlayerTags = bsread.bool(bs) + settings.playerMarkersMode = bsread.int32(bs) + settings.worldTime = bsread.uint8(bs) + settings.worldWeather = bsread.uint8(bs) + settings.gravity = bsread.float(bs) + settings.lanMode = bsread.bool(bs) + settings.deathMoneyDrop = bsread.int32(bs) + settings.instagib = bsread.bool(bs) + settings.normalOnfootSendrate = bsread.int32(bs) + settings.normalIncarSendrate = bsread.int32(bs) + settings.normalFiringSendrate = bsread.int32(bs) + settings.sendMultiplier = bsread.int32(bs) + settings.lagCompMode = bsread.int32(bs) + local hostName = bsread.string8(bs) + local vehicleModels = {} + for i = 0, 212 - 1 do + vehicleModels[i] = bsread.uint8(bs) + end + settings.vehicleFriendlyFire = bsread.bool32(bs) + return {playerId, hostName, settings, vehicleModels, settings.vehicleFriendlyFire} +end + +function handler.rpc_init_game_writer(bs, data) + local settings = data[3] + local vehicleModels = data[4] + bswrite.bool(bs, settings.zoneNames) + bswrite.bool(bs, settings.useCJWalk) + bswrite.bool(bs, settings.allowWeapons) + bswrite.bool(bs, settings.limitGlobalChatRadius) + bswrite.float(bs, settings.globalChatRadius) + bswrite.bool(bs, settings.stuntBonus) + bswrite.float(bs, settings.nametagDrawDist) + bswrite.bool(bs, settings.disableEnterExits) + bswrite.bool(bs, settings.nametagLOS) + bswrite.bool(bs, settings.tirePopping) + bswrite.int32(bs, settings.classesAvailable) + bswrite.uint16(bs, data[1]) -- playerId + bswrite.bool(bs, settings.showPlayerTags) + bswrite.int32(bs, settings.playerMarkersMode) + bswrite.uint8(bs, settings.worldTime) + bswrite.uint8(bs, settings.worldWeather) + bswrite.float(bs, settings.gravity) + bswrite.bool(bs, settings.lanMode) + bswrite.int32(bs, settings.deathMoneyDrop) + bswrite.bool(bs, settings.instagib) + bswrite.int32(bs, settings.normalOnfootSendrate) + bswrite.int32(bs, settings.normalIncarSendrate) + bswrite.int32(bs, settings.normalFiringSendrate) + bswrite.int32(bs, settings.sendMultiplier) + bswrite.int32(bs, settings.lagCompMode) + bswrite.string8(bs, data[2]) -- hostName + for i = 1, 212 do + bswrite.uint8(bs, vehicleModels[i]) + end + bswrite.bool32(bs, settings.vehicleFriendlyFire) +end + +--- onInitMenu +function handler.rpc_init_menu_reader(bs) + local colWidth2 + local rows = {} + local columns = {} + local readColumn = function(width) + local title = bsread.fixedString32(bs) + local rowCount = bsread.uint8(bs) + local column = {title = title, width = width, text = {}} + for i = 1, rowCount do + column.text[i] = bsread.fixedString32(bs) + end + return column + end + local menuId = bsread.uint8(bs) + local twoColumns = bsread.bool32(bs) + local menuTitle = bsread.fixedString32(bs) + local x = bsread.float(bs) + local y = bsread.float(bs) + local colWidth1 = bsread.float(bs) + if twoColumns then + colWidth2 = bsread.float(bs) + end + local menu = bsread.bool32(bs) + for i = 1, 12 do + rows[i] = bsread.int32(bs) + end + columns[1] = readColumn(colWidth1) + if twoColumns then + columns[2] = readColumn(colWidth2) + end + return {menuId, menuTitle, x, y, twoColumns, columns, rows, menu} +end + +function handler.rpc_init_menu_writer(bs, data) + local columns = data[6] + bswrite.uint8(bs, data[1]) -- menuId + bswrite.bool32(bs, data[5]) -- twoColumns + bswrite.fixedString32(bs, data[2]) -- title + bswrite.float(bs, data[3]) -- x + bswrite.float(bs, data[4]) -- y + -- columns width + bswrite.float(bs, columns[1].width) + if data[5] then + bswrite.float(bs, columns[2].width) + end + bswrite.bool32(bs, data[8]) -- menu + -- rows + for i = 1, 12 do + bswrite.int32(bs, data[7][i]) + end + -- columns + for i = 1, (data[5] and 2 or 1) do + bswrite.fixedString32(bs, columns[i].title) + bswrite.uint8(bs, #columns[i].text) + for r, t in ipairs(columns[i].text) do + bswrite.fixedString32(bs, t) + end + end +end + +--- onMarkersSync +function handler.packet_markers_sync_reader(bs) + local markers = {} + local players = bsread.int32(bs) + for i = 1, players do + local playerId = bsread.uint16(bs) + local active = bsread.bool(bs) + if active then + local vector3d = require 'vector3d' + local x, y, z = bsread.int16(bs), bsread.int16(bs), bsread.int16(bs) + table.insert(markers, {playerId = playerId, active = true, coords = vector3d(x, y, z)}) + else + table.insert(markers, {playerId = playerId, active = false}) + end + end + return {markers} +end + +function handler.packet_markers_sync_writer(bs, data) + bswrite.int32(bs, #data) + for i = 1, #data do + local it = data[i] + bswrite.uint16(bs, it.playerId) + bswrite.bool(bs, it.active) + if it.active then + bswrite.uint16(bs, it.coords.x) + bswrite.uint16(bs, it.coords.y) + bswrite.uint16(bs, it.coords.z) + end + end +end + +--- onPlayerSync +function handler.packet_player_sync_reader(bs) + local has_value = bsread.bool + local data = {} + local playerId = bsread.uint16(bs) + if has_value(bs) then data.leftRightKeys = bsread.uint16(bs) end + if has_value(bs) then data.upDownKeys = bsread.uint16(bs) end + data.keysData = bsread.uint16(bs) + data.position = bsread.vector3d(bs) + data.quaternion = bsread.normQuat(bs) + data.health, data.armor = utils.decompress_health_and_armor(bsread.uint8(bs)) + data.weapon = bsread.uint8(bs) + data.specialAction = bsread.uint8(bs) + data.moveSpeed = bsread.compressedVector(bs) + if has_value(bs) then + data.surfingVehicleId = bsread.uint16(bs) + data.surfingOffsets = bsread.vector3d(bs) + end + if has_value(bs) then + data.animationId = bsread.uint16(bs) + data.animationFlags = bsread.uint16(bs) + end + return {playerId, data} +end + +function handler.packet_player_sync_writer(bs, data) + local playerId = data[1] + local data = data[2] + bswrite.uint16(bs, playerId) + bswrite.bool(bs, data.leftRightKeys ~= nil) + if data.leftRightKeys then bswrite.uint16(bs, data.leftRightKeys) end + bswrite.bool(bs, data.upDownKeys ~= nil) + if data.upDownKeys then bswrite.uint16(bs, data.upDownKeys) end + bswrite.uint16(bs, data.keysData) + bswrite.vector3d(bs, data.position) + bswrite.normQuat(bs, data.quaternion) + bswrite.uint8(bs, utils.compress_health_and_armor(data.health, data.armor)) + bswrite.uint8(bs, data.weapon) + bswrite.uint8(bs, data.specialAction) + bswrite.compressedVector(bs, data.moveSpeed) + bswrite.bool(bs, data.surfingVehicleId ~= nil) + if data.surfingVehicleId then + bswrite.uint16(bs, data.surfingVehicleId) + bswrite.vector3d(bs, data.surfingOffsets) + end + bswrite.bool(bs, data.animationId ~= nil) + if data.animationId then + bswrite.uint16(bs, data.animationId) + bswrite.uint16(bs, data.animationFlags) + end +end + +--- onVehicleSync +function handler.packet_vehicle_sync_reader(bs) + local data = {} + local playerId = bsread.uint16(bs) + local vehicleId = bsread.uint16(bs) + data.leftRightKeys = bsread.uint16(bs) + data.upDownKeys = bsread.uint16(bs) + data.keysData = bsread.uint16(bs) + data.quaternion = bsread.normQuat(bs) + data.position = bsread.vector3d(bs) + data.moveSpeed = bsread.compressedVector(bs) + data.vehicleHealth = bsread.uint16(bs) + data.playerHealth, data.armor = utils.decompress_health_and_armor(bsread.uint8(bs)) + data.currentWeapon = bsread.uint8(bs) + data.siren = bsread.bool(bs) + data.landingGear = bsread.bool(bs) + if bsread.bool(bs) then + data.trainSpeed = bsread.int32(bs) + end + if bsread.bool(bs) then + data.trailerId = bsread.uint16(bs) + end + return {playerId, vehicleId, data} +end + +function handler.packet_vehicle_sync_writer(bs, data) + local playerId = data[1] + local vehicleId = data[2] + local data = data[3] + bswrite.uint16(bs, playerId) + bswrite.uint16(bs, vehicleId) + bswrite.uint16(bs, data.leftRightKeys) + bswrite.uint16(bs, data.upDownKeys) + bswrite.uint16(bs, data.keysData) + bswrite.normQuat(bs, data.quaternion) + bswrite.vector3d(bs, data.position) + bswrite.compressedVector(bs, data.moveSpeed) + bswrite.uint16(bs, data.vehicleHealth) + bswrite.uint8(bs, utils.compress_health_and_armor(data.playerHealth, data.armor)) + bswrite.uint8(bs, data.currentWeapon) + bswrite.bool(bs, data.siren) + bswrite.bool(bs, data.landingGear) + bswrite.bool(bs, data.trainSpeed ~= nil) + if data.trainSpeed ~= nil then + bswrite.int32(bs, data.trainSpeed) + end + bswrite.bool(bs, data.trailerId ~= nil) + if data.trailerId ~= nil then + bswrite.uint16(bs, data.trailerId) + end +end + +--- onVehicleStreamIn +function handler.rpc_vehicle_stream_in_reader(bs) + local data = {modSlots = {}} + local vehicleId = bsread.uint16(bs) + data.type = bsread.int32(bs) + data.position = bsread.vector3d(bs) + data.rotation = bsread.float(bs) + data.bodyColor1 = bsread.uint8(bs) + data.bodyColor2 = bsread.uint8(bs) + data.health = bsread.float(bs) + data.interiorId = bsread.uint8(bs) + data.doorDamageStatus = bsread.int32(bs) + data.panelDamageStatus = bsread.int32(bs) + data.lightDamageStatus = bsread.uint8(bs) + data.tireDamageStatus = bsread.uint8(bs) + data.addSiren = bsread.uint8(bs) + for i = 1, 14 do + data.modSlots[i] = bsread.uint8(bs) + end + data.paintJob = bsread.uint8(bs) + data.interiorColor1 = bsread.int32(bs) + data.interiorColor2 = bsread.int32(bs) + return {vehicleId, data} +end + +function handler.rpc_vehicle_stream_in_writer(bs, data) + local vehicleId = data[1] + local data = data[2] + bswrite.uint16(bs, vehicleId) + bswrite.int32(bs, data.type) + bswrite.vector3d(bs, data.position) + bswrite.float(bs, data.rotation) + bswrite.uint8(bs, data.bodyColor1) + bswrite.uint8(bs, data.bodyColor2) + bswrite.float(bs, data.health) + bswrite.uint8(bs, data.interiorId) + bswrite.int32(bs, data.doorDamageStatus) + bswrite.int32(bs, data.panelDamageStatus) + bswrite.uint8(bs, data.lightDamageStatus) + bswrite.uint8(bs, data.tireDamageStatus) + bswrite.uint8(bs, data.addSiren) + for i = 1, 14 do + bswrite.uint8(bs, data.modSlots[i]) + end + bswrite.uint8(bs, data.paintJob) + bswrite.int32(bs, data.interiorColor1) + bswrite.int32(bs, data.interiorColor2) +end + +local MATERIAL_TYPE = { + NONE = 0, + TEXTURE = 1, + TEXT = 2, +} + +local function read_object_material(bs) + local data = {} + data.materialId = bsread.uint8(bs) + data.modelId = bsread.uint16(bs) + data.libraryName = bsread.string8(bs) + data.textureName = bsread.string8(bs) + data.color = bsread.int32(bs) + data.type = MATERIAL_TYPE.TEXTURE + return data +end + +local function write_object_material(bs, data) + bswrite.uint8(bs, data.type) + bswrite.uint8(bs, data.materialId) + bswrite.uint16(bs, data.modelId) + bswrite.string8(bs, data.libraryName) + bswrite.string8(bs, data.textureName) + bswrite.int32(bs, data.color) +end + +local function read_object_material_text(bs) + local data = {} + data.materialId = bsread.uint8(bs) + data.materialSize = bsread.uint8(bs) + data.fontName = bsread.string8(bs) + data.fontSize = bsread.uint8(bs) + data.bold = bsread.uint8(bs) + data.fontColor = bsread.int32(bs) + data.backGroundColor = bsread.int32(bs) + data.align = bsread.uint8(bs) + data.text = bsread.encodedString2048(bs) + data.type = MATERIAL_TYPE.TEXT + return data +end + +local function write_object_material_text(bs, data) + bswrite.uint8(bs, data.type) + bswrite.uint8(bs, data.materialId) + bswrite.uint8(bs, data.materialSize) + bswrite.string8(bs, data.fontName) + bswrite.uint8(bs, data.fontSize) + bswrite.uint8(bs, data.bold) + bswrite.int32(bs, data.fontColor) + bswrite.int32(bs, data.backGroundColor) + bswrite.uint8(bs, data.align) + bswrite.encodedString2048(bs, data.text) +end + +--- onSetObjectMaterial +function handler.rpc_set_object_material_reader(bs) + local objectId = bsread.uint16(bs) + local materialType = bsread.uint8(bs) + local material + if materialType == MATERIAL_TYPE.TEXTURE then + material = read_object_material(bs) + elseif materialType == MATERIAL_TYPE.TEXT then + material = read_object_material_text(bs) + end + local ev = materialType == MATERIAL_TYPE.TEXTURE and 'onSetObjectMaterial' or 'onSetObjectMaterialText' + return ev, {objectId, material} +end + +function handler.rpc_set_object_material_writer(bs, data) + local objectId = data[1] + local mat = data[2] + bswrite.uint16(bs, objectId) + if mat.type == MATERIAL_TYPE.TEXTURE then + write_object_material(bs, mat) + elseif mat.type == MATERIAL_TYPE.TEXT then + write_object_material_text(bs, mat) + end +end + +--- onCreateObject +function handler.rpc_create_object_reader(bs) + local data = {materials = {}, materialText = {}} + local objectId = bsread.uint16(bs) + data.modelId = bsread.int32(bs) + data.position = bsread.vector3d(bs) + data.rotation = bsread.vector3d(bs) + data.drawDistance = bsread.float(bs) + data.noCameraCol = bsread.bool8(bs) + data.attachToVehicleId = bsread.uint16(bs) + data.attachToObjectId = bsread.uint16(bs) + if data.attachToVehicleId ~= 0xFFFF or data.attachToObjectId ~= 0xFFFF then + data.attachOffsets = bsread.vector3d(bs) + data.attachRotation = bsread.vector3d(bs) + data.syncRotation = bsread.bool8(bs) + end + data.texturesCount = bsread.uint8(bs) + while raknetBitStreamGetNumberOfUnreadBits(bs) >= 8 do + local materialType = bsread.uint8(bs) + if materialType == MATERIAL_TYPE.TEXTURE then + table.insert(data.materials, read_object_material(bs)) + elseif materialType == MATERIAL_TYPE.TEXT then + table.insert(data.materialText, read_object_material_text(bs)) + end + end + data.materials_text = data.materialText -- obsolete + return {objectId, data} +end + +function handler.rpc_create_object_writer(bs, data) + local objectId = data[1] + local data = data[2] + bswrite.uint16(bs, objectId) + bswrite.int32(bs, data.modelId) + bswrite.vector3d(bs, data.position) + bswrite.vector3d(bs, data.rotation) + bswrite.float(bs, data.drawDistance) + bswrite.bool8(bs, data.noCameraCol) + bswrite.uint16(bs, data.attachToVehicleId) + bswrite.uint16(bs, data.attachToObjectId) + if data.attachToVehicleId ~= 0xFFFF or data.attachToObjectId ~= 0xFFFF then + bswrite.vector3d(bs, data.attachOffsets) + bswrite.vector3d(bs, data.attachRotation) + bswrite.bool8(bs, data.syncRotation) + end + bswrite.uint8(bs, data.texturesCount) + for _, it in ipairs(data.materials) do + write_object_material(bs, it) + end + for _, it in ipairs(data.materialText) do + write_object_material_text(bs, it) + end +end + +function handler.rpc_update_scores_and_pings_reader(bs) + local data = {} + for i = 1, raknetBitStreamGetNumberOfBytesUsed(bs) / 10 do + local playerId = bsread.uint16(bs) + local playerScore = bsread.int32(bs) + local playerPing = bsread.int32(bs) + data[playerId] = {score = playerScore, ping = playerPing} + end + return {data} +end + +function handler.rpc_update_scores_and_pings_writer(bs, data) + for id, info in pairs(data[1]) do + bswrite.uint16(bs, id) + bswrite.int32(bs, info.score) + bswrite.int32(bs, info.ping) + end +end + +function handler.packet_weapons_update_reader(bs) + local playerTarget = bsread.uint16(bs) + local actorTarget = bsread.uint16(bs) + local weapons = {} + local count = raknetBitStreamGetNumberOfUnreadBits(bs) / 32 + for i = 1, count do + local slot = bsread.uint8(bs) + local weapon = bsread.uint8(bs) + local ammo = bsread.uint16(bs) + weapons[i] = {slot = slot, weapon = weapon, ammo = ammo} + end + return {playerTarget, actorTarget, weapons} +end + +function handler.packet_weapons_update_writer(bs, data) + bswrite.uint16(bs, data[1]) + bswrite.uint16(bs, data[2]) + for i, weap in ipairs(data[3]) do + bswrite.uint8(bs, weap.slot) + bswrite.uint8(bs, weap.weapon) + bswrite.uint16(bs, weap.ammo) + end +end + +return handler diff --git a/samp/events/utils.lua b/samp/events/utils.lua new file mode 100644 index 0000000..81960a4 --- /dev/null +++ b/samp/events/utils.lua @@ -0,0 +1,45 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local ffi = require 'ffi' +local utils = {} + +function utils.decompress_health_and_armor(hpAp) + local hp = math.min(bit.rshift(hpAp, 4) * 7, 100) + local armor = math.min(bit.band(hpAp, 0x0F) * 7, 100) + return hp, armor +end + +function utils.compress_health_and_armor(health, armor) + local hp = health >= 100 and 0xF0 or bit.lshift(health / 7, 4) + local ap = armor >= 100 and 0x0F or bit.band(armor / 7, 0x0F) + return bit.bor(hp, ap) +end + +function utils.create_sync_data(st) + require 'samp.synchronization' + return ffi.new(st) +end + +function utils.read_sync_data(bs, st) + local dataStruct = utils.create_sync_data(st) + local ptr = tonumber(ffi.cast('intptr_t', ffi.cast('void*', dataStruct))) + raknetBitStreamReadBuffer(bs, ptr, ffi.sizeof(dataStruct)) + return dataStruct +end + +function utils.write_sync_data(bs, st, ffiobj) + require 'samp.synchronization' + local ptr = tonumber(ffi.cast('intptr_t', ffi.cast('void*', ffiobj))) + raknetBitStreamWriteBuffer(bs, ptr, ffi.sizeof(st)) +end + +function utils.process_outcoming_sync_data(bs, st) + local data = raknetBitStreamGetDataPtr(bs) + 1 + require 'samp.synchronization' + return {ffi.cast(st .. '*', data)} +end + +return utils diff --git a/samp/raknet.lua b/samp/raknet.lua new file mode 100644 index 0000000..8e004dc --- /dev/null +++ b/samp/raknet.lua @@ -0,0 +1,236 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local mod = +{ + MODULEINFO = { + name = 'samp.raknet', + version = 2 + } +} +require 'sampfuncs' + +mod.RPC = { + CLICKPLAYER = RPC_CLICKPLAYER, + CLIENTJOIN = RPC_CLIENTJOIN, + ENTERVEHICLE = RPC_ENTERVEHICLE, + SCRIPTCASH = RPC_SCRIPTCASH, + SERVERCOMMAND = RPC_SERVERCOMMAND, + SPAWN = RPC_SPAWN, + DEATH = RPC_DEATH, + NPCJOIN = RPC_NPCJOIN, + DIALOGRESPONSE = RPC_DIALOGRESPONSE, + CLICKTEXTDRAW = RPC_CLICKTEXTDRAW, + SCMEVENT = RPC_SCMEVENT, + WEAPONPICKUPDESTROY = RPC_WEAPONPICKUPDESTROY, + CHAT = RPC_CHAT, + SRVNETSTATS = RPC_SRVNETSTATS, + CLIENTCHECK = RPC_CLIENTCHECK, + DAMAGEVEHICLE = RPC_DAMAGEVEHICLE, + GIVETAKEDAMAGE = RPC_GIVETAKEDAMAGE, + EDITATTACHEDOBJECT = RPC_EDITATTACHEDOBJECT, + EDITOBJECT = RPC_EDITOBJECT, + SETINTERIORID = RPC_SETINTERIORID, + MAPMARKER = RPC_MAPMARKER, + REQUESTCLASS = RPC_REQUESTCLASS, + REQUESTSPAWN = RPC_REQUESTSPAWN, + PICKEDUPPICKUP = RPC_PICKEDUPPICKUP, + MENUSELECT = RPC_MENUSELECT, + VEHICLEDESTROYED = RPC_VEHICLEDESTROYED, + MENUQUIT = RPC_MENUQUIT, + EXITVEHICLE = RPC_EXITVEHICLE, + UPDATESCORESPINGSIPS = RPC_UPDATESCORESPINGSIPS, + CAMTARGETUPDATE = 168, + GIVEACTORDAMAGE = 177, + + CONNECTIONREJECTED = 130, + SETPLAYERNAME = RPC_SCRSETPLAYERNAME, + SETPLAYERPOS = RPC_SCRSETPLAYERPOS, + SETPLAYERPOSFINDZ = RPC_SCRSETPLAYERPOSFINDZ, + SETPLAYERHEALTH = RPC_SCRSETPLAYERHEALTH, + TOGGLEPLAYERCONTROLLABLE = RPC_SCRTOGGLEPLAYERCONTROLLABLE, + PLAYSOUND = RPC_SCRPLAYSOUND, + SETPLAYERWORLDBOUNDS = RPC_SCRSETPLAYERWORLDBOUNDS, + GIVEPLAYERMONEY = RPC_SCRGIVEPLAYERMONEY, + SETPLAYERFACINGANGLE = RPC_SCRSETPLAYERFACINGANGLE, + RESETPLAYERMONEY = RPC_SCRRESETPLAYERMONEY, + RESETPLAYERWEAPONS = RPC_SCRRESETPLAYERWEAPONS, + GIVEPLAYERWEAPON = RPC_SCRGIVEPLAYERWEAPON, + SETVEHICLEPARAMSEX = RPC_SCRSETVEHICLEPARAMSEX, + CANCELEDIT = RPC_SCRCANCELEDIT, + SETPLAYERTIME = RPC_SCRSETPLAYERTIME, + TOGGLECLOCK = RPC_SCRTOGGLECLOCK, + WORLDPLAYERADD = RPC_SCRWORLDPLAYERADD, + SETPLAYERSHOPNAME = RPC_SCRSETPLAYERSHOPNAME, + SETPLAYERSKILLLEVEL = RPC_SCRSETPLAYERSKILLLEVEL, + SETPLAYERDRUNKLEVEL = RPC_SCRSETPLAYERDRUNKLEVEL, + CREATE3DTEXTLABEL = RPC_SCRCREATE3DTEXTLABEL, + DISABLECHECKPOINT = RPC_SCRDISABLECHECKPOINT, + SETRACECHECKPOINT = RPC_SCRSETRACECHECKPOINT, + DISABLERACECHECKPOINT = RPC_SCRDISABLERACECHECKPOINT, + GAMEMODERESTART = RPC_SCRGAMEMODERESTART, + PLAYAUDIOSTREAM = RPC_SCRPLAYAUDIOSTREAM, + STOPAUDIOSTREAM = RPC_SCRSTOPAUDIOSTREAM, + REMOVEBUILDINGFORPLAYER = RPC_SCRREMOVEBUILDINGFORPLAYER, + CREATEOBJECT = RPC_SCRCREATEOBJECT, + SETOBJECTPOS = RPC_SCRSETOBJECTPOS, + SETOBJECTROT = RPC_SCRSETOBJECTROT, + DESTROYOBJECT = RPC_SCRDESTROYOBJECT, + DEATHMESSAGE = RPC_SCRDEATHMESSAGE, + SETPLAYERMAPICON = RPC_SCRSETPLAYERMAPICON, + REMOVEVEHICLECOMPONENT = RPC_SCRREMOVEVEHICLECOMPONENT, + CHATBUBBLE = RPC_SCRCHATBUBBLE, + UPDATETIME = RPC_SCRSOMEUPDATE, + SHOWDIALOG = RPC_SCRSHOWDIALOG, + DESTROYPICKUP = RPC_SCRDESTROYPICKUP, + LINKVEHICLETOINTERIOR = RPC_SCRLINKVEHICLETOINTERIOR, + SETPLAYERARMOUR = RPC_SCRSETPLAYERARMOUR, + SETPLAYERARMEDWEAPON = RPC_SCRSETPLAYERARMEDWEAPON, + SETSPAWNINFO = RPC_SCRSETSPAWNINFO, + SETPLAYERTEAM = RPC_SCRSETPLAYERTEAM, + PUTPLAYERINVEHICLE = RPC_SCRPUTPLAYERINVEHICLE, + REMOVEPLAYERFROMVEHICLE = RPC_SCRREMOVEPLAYERFROMVEHICLE, + SETPLAYERCOLOR = RPC_SCRSETPLAYERCOLOR, + DISPLAYGAMETEXT = RPC_SCRDISPLAYGAMETEXT, + FORCECLASSSELECTION = RPC_SCRFORCECLASSSELECTION, + ATTACHOBJECTTOPLAYER = RPC_SCRATTACHOBJECTTOPLAYER, + INITMENU = RPC_SCRINITMENU, + SHOWMENU = RPC_SCRSHOWMENU, + HIDEMENU = RPC_SCRHIDEMENU, + CREATEEXPLOSION = RPC_SCRCREATEEXPLOSION, + SHOWPLAYERNAMETAGFORPLAYER = RPC_SCRSHOWPLAYERNAMETAGFORPLAYER, + ATTACHCAMERATOOBJECT = RPC_SCRATTACHCAMERATOOBJECT, + INTERPOLATECAMERA = RPC_SCRINTERPOLATECAMERA, + SETOBJECTMATERIAL = RPC_SCRSETOBJECTMATERIAL, + GANGZONESTOPFLASH = RPC_SCRGANGZONESTOPFLASH, + APPLYANIMATION = RPC_SCRAPPLYANIMATION, + CLEARANIMATIONS = RPC_SCRCLEARANIMATIONS, + SETPLAYERSPECIALACTION = RPC_SCRSETPLAYERSPECIALACTION, + SETPLAYERFIGHTINGSTYLE = RPC_SCRSETPLAYERFIGHTINGSTYLE, + SETPLAYERVELOCITY = RPC_SCRSETPLAYERVELOCITY, + SETVEHICLEVELOCITY = RPC_SCRSETVEHICLEVELOCITY, + CLIENTMESSAGE = RPC_SCRCLIENTMESSAGE, + SETWORLDTIME = RPC_SCRSETWORLDTIME, + CREATEPICKUP = RPC_SCRCREATEPICKUP, + MOVEOBJECT = RPC_SCRMOVEOBJECT, + ENABLESTUNTBONUSFORPLAYER = RPC_SCRENABLESTUNTBONUSFORPLAYER, + TEXTDRAWSETSTRING = RPC_SCRTEXTDRAWSETSTRING, + SETCHECKPOINT = RPC_SCRSETCHECKPOINT, + GANGZONECREATE = RPC_SCRGANGZONECREATE, + PLAYCRIMEREPORT = RPC_SCRPLAYCRIMEREPORT, + SETPLAYERATTACHEDOBJECT = RPC_SCRSETPLAYERATTACHEDOBJECT, + GANGZONEDESTROY = RPC_SCRGANGZONEDESTROY, + GANGZONEFLASH = RPC_SCRGANGZONEFLASH, + STOPOBJECT = RPC_SCRSTOPOBJECT, + SETNUMBERPLATE = RPC_SCRSETNUMBERPLATE, + TOGGLEPLAYERSPECTATING = RPC_SCRTOGGLEPLAYERSPECTATING, + PLAYERSPECTATEPLAYER = RPC_SCRPLAYERSPECTATEPLAYER, + PLAYERSPECTATEVEHICLE = RPC_SCRPLAYERSPECTATEVEHICLE, + SETPLAYERWANTEDLEVEL = RPC_SCRSETPLAYERWANTEDLEVEL, + SHOWTEXTDRAW = RPC_SCRSHOWTEXTDRAW, + TEXTDRAWHIDEFORPLAYER = RPC_SCRTEXTDRAWHIDEFORPLAYER, + SERVERJOIN = RPC_SCRSERVERJOIN, + SERVERQUIT = RPC_SCRSERVERQUIT, + INITGAME = RPC_SCRINITGAME, + REMOVEPLAYERMAPICON = RPC_SCRREMOVEPLAYERMAPICON, + SETPLAYERAMMO = RPC_SCRSETPLAYERAMMO, + SETGRAVITY = RPC_SCRSETGRAVITY, + SETVEHICLEHEALTH = RPC_SCRSETVEHICLEHEALTH, + ATTACHTRAILERTOVEHICLE = RPC_SCRATTACHTRAILERTOVEHICLE, + DETACHTRAILERFROMVEHICLE = RPC_SCRDETACHTRAILERFROMVEHICLE, + SETWEATHER = RPC_SCRSETWEATHER, + SETPLAYERSKIN = RPC_SCRSETPLAYERSKIN, + SETPLAYERINTERIOR = RPC_SCRSETPLAYERINTERIOR, + SETPLAYERCAMERAPOS = RPC_SCRSETPLAYERCAMERAPOS, + SETPLAYERCAMERALOOKAT = RPC_SCRSETPLAYERCAMERALOOKAT, + SETVEHICLEPOS = RPC_SCRSETVEHICLEPOS, + SETVEHICLEZANGLE = RPC_SCRSETVEHICLEZANGLE, + SETVEHICLEPARAMSFORPLAYER = RPC_SCRSETVEHICLEPARAMSFORPLAYER, + SETCAMERABEHINDPLAYER = RPC_SCRSETCAMERABEHINDPLAYER, + WORLDPLAYERREMOVE = RPC_SCRWORLDPLAYERREMOVE, + WORLDVEHICLEADD = RPC_SCRWORLDVEHICLEADD, + WORLDVEHICLEREMOVE = RPC_SCRWORLDVEHICLEREMOVE, + WORLDPLAYERDEATH = RPC_SCRWORLDPLAYERDEATH, + CREATEACTOR = 171, + DESTROYACTOR = 172, + DESTROY3DTEXTLABEL = 58, + DESTROYWEAPONPICKUP = 151, + TOGGLECAMERATARGET = 170, + SELECTOBJECT = 27, + DISABLEVEHICLECOLLISIONS = 167, + TOGGLEWIDESCREEN = 111, + SETVEHICLETIRES = 98, + SETPLAYERDRUNKVISUALS = 92, + SETPLAYERDRUNKHANDLING = 150, + APPLYACTORANIMATION = 173, + CLEARACTORANIMATION = 174, + SETACTORROTATION = 175, + SETACTORPOSITION = 176, + SETACTORHEALTH = 178, + SETPLAYEROBJECTNOCAMCOL = 169, + + -- Invalid. Retained only for backward compatibility. + ENTEREDITOBJECT = RPC_ENTEREDITOBJECT, + UPDATE3DTEXTLABEL = RPC_SCRUPDATE3DTEXTLABEL, +} + +mod.PACKET = { + VEHICLE_SYNC = PACKET_VEHICLE_SYNC, + RCON_COMMAND = PACKET_RCON_COMMAND, + RCON_RESPONCE = PACKET_RCON_RESPONCE, + AIM_SYNC = PACKET_AIM_SYNC, + WEAPONS_UPDATE = PACKET_WEAPONS_UPDATE, + STATS_UPDATE = PACKET_STATS_UPDATE, + BULLET_SYNC = PACKET_BULLET_SYNC, + PLAYER_SYNC = PACKET_PLAYER_SYNC, + MARKERS_SYNC = PACKET_MARKERS_SYNC, + UNOCCUPIED_SYNC = PACKET_UNOCCUPIED_SYNC, + TRAILER_SYNC = PACKET_TRAILER_SYNC, + PASSENGER_SYNC = PACKET_PASSENGER_SYNC, + SPECTATOR_SYNC = PACKET_SPECTATOR_SYNC, + + INTERNAL_PING = PACKET_INTERNAL_PING, + PING = PACKET_PING, + PING_OPEN_CONNECTIONS = PACKET_PING_OPEN_CONNECTIONS, + CONNECTED_PONG = PACKET_CONNECTED_PONG, + REQUEST_STATIC_DATA = PACKET_REQUEST_STATIC_DATA, + CONNECTION_REQUEST = PACKET_CONNECTION_REQUEST, + AUTHENTICATION = PACKET_AUTH_KEY, + BROADCAST_PINGS = PACKET_BROADCAST_PINGS, + SECURED_CONNECTION_RESPONSE = PACKET_SECURED_CONNECTION_RESPONSE, + SECURED_CONNECTION_CONFIRMATION = PACKET_SECURED_CONNECTION_CONFIRMATION, + RPC_MAPPING = PACKET_RPC_MAPPING, + SET_RANDOM_NUMBER_SEED = PACKET_SET_RANDOM_NUMBER_SEED, + RPC = PACKET_RPC, + RPC_REPLY = PACKET_RPC_REPLY, + DETECT_LOST_CONNECTIONS = PACKET_DETECT_LOST_CONNECTIONS, + OPEN_CONNECTION_REQUEST = PACKET_OPEN_CONNECTION_REQUEST, + OPEN_CONNECTION_REPLY = PACKET_OPEN_CONNECTION_REPLY, + CONNECTION_COOKIE = PACKET_CONNECTION_COOKIE, + RSA_PUBLIC_KEY_MISMATCH = PACKET_RSA_PUBLIC_KEY_MISMATCH, + CONNECTION_ATTEMPT_FAILED = PACKET_CONNECTION_ATTEMPT_FAILED, + NEW_INCOMING_CONNECTION = PACKET_NEW_INCOMING_CONNECTION, + NO_FREE_INCOMING_CONNECTIONS = PACKET_NO_FREE_INCOMING_CONNECTIONS, + DISCONNECTION_NOTIFICATION = PACKET_DISCONNECTION_NOTIFICATION, + CONNECTION_LOST = PACKET_CONNECTION_LOST, + CONNECTION_REQUEST_ACCEPTED = PACKET_CONNECTION_REQUEST_ACCEPTED, + INITIALIZE_ENCRYPTION = PACKET_INITIALIZE_ENCRYPTION, + CONNECTION_BANNED = PACKET_CONNECTION_BANNED, + INVALID_PASSWORD = PACKET_INVALID_PASSWORD, + MODIFIED_PACKET = PACKET_MODIFIED_PACKET, + PONG = PACKET_PONG, + TIMESTAMP = PACKET_TIMESTAMP, + RECEIVED_STATIC_DATA = PACKET_RECEIVED_STATIC_DATA, + REMOTE_DISCONNECTION_NOTIFICATION = PACKET_REMOTE_DISCONNECTION_NOTIFICATION, + REMOTE_CONNECTION_LOST = PACKET_REMOTE_CONNECTION_LOST, + REMOTE_NEW_INCOMING_CONNECTION = PACKET_REMOTE_NEW_INCOMING_CONNECTION, + REMOTE_EXISTING_CONNECTION = PACKET_REMOTE_EXISTING_CONNECTION, + REMOTE_STATIC_DATA = PACKET_REMOTE_STATIC_DATA, + ADVERTISE_SYSTEM = PACKET_ADVERTISE_SYSTEM, + + AUTH_KEY = PACKET_AUTH_KEY, +} + +return mod diff --git a/samp/synchronization.lua b/samp/synchronization.lua new file mode 100644 index 0000000..65c15dc --- /dev/null +++ b/samp/synchronization.lua @@ -0,0 +1,199 @@ +-- This file is part of the SAMP.Lua project. +-- Licensed under the MIT License. +-- Copyright (c) 2016, FYP @ BlastHack Team +-- https://github.com/THE-FYP/SAMP.Lua + +local mod = +{ + MODULEINFO = { + name = 'samp.synchronization', + version = 2 + } +} +local ffi = require 'ffi' + +ffi.cdef[[ +#pragma pack(push, 1) + +typedef struct VectorXYZ { + float x, y, z; +} VectorXYZ; + +typedef struct SampKeys { + uint8_t primaryFire : 1; + uint8_t horn_crouch : 1; + uint8_t secondaryFire_shoot : 1; + uint8_t accel_zoomOut : 1; + uint8_t enterExitCar : 1; + uint8_t decel_jump : 1; + uint8_t circleRight : 1; + uint8_t aim : 1; + uint8_t circleLeft : 1; + uint8_t landingGear_lookback : 1; + uint8_t unknown_walkSlow : 1; + uint8_t specialCtrlUp : 1; + uint8_t specialCtrlDown : 1; + uint8_t specialCtrlLeft : 1; + uint8_t specialCtrlRight : 1; + uint8_t _unknown : 1; +} SampKeys; + +typedef struct PlayerSyncData { + uint16_t leftRightKeys; + uint16_t upDownKeys; + union { + uint16_t keysData; + SampKeys keys; + }; + VectorXYZ position; + float quaternion[4]; + uint8_t health; + uint8_t armor; + uint8_t weapon : 6; + uint8_t specialKey : 2; + uint8_t specialAction; + VectorXYZ moveSpeed; + VectorXYZ surfingOffsets; + uint16_t surfingVehicleId; + union { + struct { + uint16_t id; + uint8_t frameDelta; + union { + struct { + bool loop : 1; + bool lockX : 1; + bool lockY : 1; + bool freeze : 1; + uint8_t time : 2; + uint8_t _unused : 1; + bool regular : 1; + }; + uint8_t value; + } flags; + } animation; + struct { + uint16_t animationId; + uint16_t animationFlags; + }; + }; +} PlayerSyncData; + +typedef struct VehicleSyncData { + uint16_t vehicleId; + uint16_t leftRightKeys; + uint16_t upDownKeys; + union { + uint16_t keysData; + SampKeys keys; + }; + float quaternion[4]; + VectorXYZ position; + VectorXYZ moveSpeed; + float vehicleHealth; + uint8_t playerHealth; + uint8_t armor; + uint8_t currentWeapon : 6; + uint8_t specialKey : 2; + uint8_t siren; + uint8_t landingGearState; + uint16_t trailerId; + union { + float bikeLean; + float trainSpeed; + uint16_t hydraThrustAngle[2]; + }; +} VehicleSyncData; + +typedef struct PassengerSyncData { + uint16_t vehicleId; + uint8_t seatId : 6; + bool driveBy : 1; + bool cuffed : 1; + uint8_t currentWeapon : 6; + uint8_t specialKey : 2; + uint8_t health; + uint8_t armor; + uint16_t leftRightKeys; + uint16_t upDownKeys; + union { + uint16_t keysData; + SampKeys keys; + }; + VectorXYZ position; +} PassengerSyncData; + +typedef struct UnoccupiedSyncData { + uint16_t vehicleId; + uint8_t seatId; + VectorXYZ roll; + VectorXYZ direction; + VectorXYZ position; + VectorXYZ moveSpeed; + VectorXYZ turnSpeed; + float vehicleHealth; +} UnoccupiedSyncData; + +typedef struct TrailerSyncData { + uint16_t trailerId; + VectorXYZ position; + union { + struct { + float quaternion[4]; + VectorXYZ moveSpeed; + VectorXYZ turnSpeed; + }; + /* Invalid. Retained for backwards compatibility. */ + struct { + VectorXYZ roll; + VectorXYZ direction; + VectorXYZ speed; + uint32_t unk; + }; + }; +} TrailerSyncData; + +typedef struct SpectatorSyncData { + uint16_t leftRightKeys; + uint16_t upDownKeys; + union { + uint16_t keysData; + SampKeys keys; + }; + VectorXYZ position; +} SpectatorSyncData; + +typedef struct BulletSyncData { + uint8_t targetType; + uint16_t targetId; + VectorXYZ origin; + VectorXYZ target; + VectorXYZ center; + uint8_t weaponId; +} BulletSyncData; + +typedef struct AimSyncData { + uint8_t camMode; + VectorXYZ camFront; + VectorXYZ camPos; + float aimZ; + uint8_t camExtZoom : 6; + uint8_t weaponState : 2; + uint8_t aspectRatio; +} AimSyncData; + +#pragma pack(pop) +]] + +assert(ffi.sizeof('VectorXYZ') == 12) +assert(ffi.sizeof('SampKeys') == 2) +assert(ffi.sizeof('PlayerSyncData') == 68) +assert(ffi.sizeof('VehicleSyncData') == 63) +assert(ffi.sizeof('PassengerSyncData') == 24) +assert(ffi.sizeof('UnoccupiedSyncData') == 67) +assert(ffi.sizeof('TrailerSyncData') == 54) +assert(ffi.sizeof('SpectatorSyncData') == 18) +assert(ffi.sizeof('BulletSyncData') == 40) +assert(ffi.sizeof('AimSyncData') == 31) + +return mod diff --git a/socket.lua b/socket.lua new file mode 100644 index 0000000..3913e6f --- /dev/null +++ b/socket.lua @@ -0,0 +1,149 @@ +----------------------------------------------------------------------------- +-- LuaSocket helper module +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local string = require("string") +local math = require("math") +local socket = require("socket.core") + +local _M = socket + +----------------------------------------------------------------------------- +-- Exported auxiliar functions +----------------------------------------------------------------------------- +function _M.connect4(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet") +end + +function _M.connect6(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet6") +end + +function _M.bind(host, port, backlog) + if host == "*" then host = "0.0.0.0" end + local addrinfo, err = socket.dns.getaddrinfo(host); + if not addrinfo then return nil, err end + local sock, res + err = "no info on address" + for i, alt in base.ipairs(addrinfo) do + if alt.family == "inet" then + sock, err = socket.tcp() + else + sock, err = socket.tcp6() + end + if not sock then return nil, err end + sock:setoption("reuseaddr", true) + res, err = sock:bind(alt.addr, port) + if not res then + sock:close() + else + res, err = sock:listen(backlog) + if not res then + sock:close() + else + return sock + end + end + end + return nil, err +end + +_M.try = _M.newtry() + +function _M.choose(table) + return function(name, opt1, opt2) + if base.type(name) ~= "string" then + name, opt1, opt2 = "default", name, opt1 + end + local f = table[name or "nil"] + if not f then base.error("unknown key (".. base.tostring(name) ..")", 3) + else return f(opt1, opt2) end + end +end + +----------------------------------------------------------------------------- +-- Socket sources and sinks, conforming to LTN12 +----------------------------------------------------------------------------- +-- create namespaces inside LuaSocket namespace +local sourcet, sinkt = {}, {} +_M.sourcet = sourcet +_M.sinkt = sinkt + +_M.BLOCKSIZE = 2048 + +sinkt["close-when-done"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then + sock:close() + return 1 + else return sock:send(chunk) end + end + }) +end + +sinkt["keep-open"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if chunk then return sock:send(chunk) + else return 1 end + end + }) +end + +sinkt["default"] = sinkt["keep-open"] + +_M.sink = _M.choose(sinkt) + +sourcet["by-length"] = function(sock, length) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + if length <= 0 then return nil end + local size = math.min(socket.BLOCKSIZE, length) + local chunk, err = sock:receive(size) + if err then return nil, err end + length = length - string.len(chunk) + return chunk + end + }) +end + +sourcet["until-closed"] = function(sock) + local done + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + if done then return nil end + local chunk, err, partial = sock:receive(socket.BLOCKSIZE) + if not err then return chunk + elseif err == "closed" then + sock:close() + done = 1 + return partial + else return nil, err end + end + }) +end + + +sourcet["default"] = sourcet["until-closed"] + +_M.source = _M.choose(sourcet) + +return _M diff --git a/socket/core.dll b/socket/core.dll new file mode 100644 index 0000000..67cd119 Binary files /dev/null and b/socket/core.dll differ diff --git a/socket/ftp.lua b/socket/ftp.lua new file mode 100644 index 0000000..ea1145b --- /dev/null +++ b/socket/ftp.lua @@ -0,0 +1,285 @@ +----------------------------------------------------------------------------- +-- FTP support for the Lua language +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local table = require("table") +local string = require("string") +local math = require("math") +local socket = require("socket") +local url = require("socket.url") +local tp = require("socket.tp") +local ltn12 = require("ltn12") +socket.ftp = {} +local _M = socket.ftp +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- timeout in seconds before the program gives up on a connection +_M.TIMEOUT = 60 +-- default port for ftp service +_M.PORT = 21 +-- this is the default anonymous password. used when no password is +-- provided in url. should be changed to your e-mail. +_M.USER = "ftp" +_M.PASSWORD = "anonymous@anonymous.org" + +----------------------------------------------------------------------------- +-- Low level FTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(server, port, create) + local tp = socket.try(tp.connect(server, port or _M.PORT, _M.TIMEOUT, create)) + local f = base.setmetatable({ tp = tp }, metat) + -- make sure everything gets closed in an exception + f.try = socket.newtry(function() f:close() end) + return f +end + +function metat.__index:portconnect() + self.try(self.server:settimeout(_M.TIMEOUT)) + self.data = self.try(self.server:accept()) + self.try(self.data:settimeout(_M.TIMEOUT)) +end + +function metat.__index:pasvconnect() + self.data = self.try(socket.tcp()) + self.try(self.data:settimeout(_M.TIMEOUT)) + self.try(self.data:connect(self.pasvt.ip, self.pasvt.port)) +end + +function metat.__index:login(user, password) + self.try(self.tp:command("user", user or _M.USER)) + local code, reply = self.try(self.tp:check{"2..", 331}) + if code == 331 then + self.try(self.tp:command("pass", password or _M.PASSWORD)) + self.try(self.tp:check("2..")) + end + return 1 +end + +function metat.__index:pasv() + self.try(self.tp:command("pasv")) + local code, reply = self.try(self.tp:check("2..")) + local pattern = "(%d+)%D(%d+)%D(%d+)%D(%d+)%D(%d+)%D(%d+)" + local a, b, c, d, p1, p2 = socket.skip(2, string.find(reply, pattern)) + self.try(a and b and c and d and p1 and p2, reply) + self.pasvt = { + ip = string.format("%d.%d.%d.%d", a, b, c, d), + port = p1*256 + p2 + } + if self.server then + self.server:close() + self.server = nil + end + return self.pasvt.ip, self.pasvt.port +end + +function metat.__index:port(ip, port) + self.pasvt = nil + if not ip then + ip, port = self.try(self.tp:getcontrol():getsockname()) + self.server = self.try(socket.bind(ip, 0)) + ip, port = self.try(self.server:getsockname()) + self.try(self.server:settimeout(_M.TIMEOUT)) + end + local pl = math.mod(port, 256) + local ph = (port - pl)/256 + local arg = string.gsub(string.format("%s,%d,%d", ip, ph, pl), "%.", ",") + self.try(self.tp:command("port", arg)) + self.try(self.tp:check("2..")) + return 1 +end + +function metat.__index:send(sendt) + self.try(self.pasvt or self.server, "need port or pasv first") + -- if there is a pasvt table, we already sent a PASV command + -- we just get the data connection into self.data + if self.pasvt then self:pasvconnect() end + -- get the transfer argument and command + local argument = sendt.argument or + url.unescape(string.gsub(sendt.path or "", "^[/\\]", "")) + if argument == "" then argument = nil end + local command = sendt.command or "stor" + -- send the transfer command and check the reply + self.try(self.tp:command(command, argument)) + local code, reply = self.try(self.tp:check{"2..", "1.."}) + -- if there is not a a pasvt table, then there is a server + -- and we already sent a PORT command + if not self.pasvt then self:portconnect() end + -- get the sink, source and step for the transfer + local step = sendt.step or ltn12.pump.step + local readt = {self.tp.c} + local checkstep = function(src, snk) + -- check status in control connection while downloading + local readyt = socket.select(readt, nil, 0) + if readyt[tp] then code = self.try(self.tp:check("2..")) end + return step(src, snk) + end + local sink = socket.sink("close-when-done", self.data) + -- transfer all data and check error + self.try(ltn12.pump.all(sendt.source, sink, checkstep)) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + -- done with data connection + self.data:close() + -- find out how many bytes were sent + local sent = socket.skip(1, self.data:getstats()) + self.data = nil + return sent +end + +function metat.__index:receive(recvt) + self.try(self.pasvt or self.server, "need port or pasv first") + if self.pasvt then self:pasvconnect() end + local argument = recvt.argument or + url.unescape(string.gsub(recvt.path or "", "^[/\\]", "")) + if argument == "" then argument = nil end + local command = recvt.command or "retr" + self.try(self.tp:command(command, argument)) + local code,reply = self.try(self.tp:check{"1..", "2.."}) + if (code >= 200) and (code <= 299) then + recvt.sink(reply) + return 1 + end + if not self.pasvt then self:portconnect() end + local source = socket.source("until-closed", self.data) + local step = recvt.step or ltn12.pump.step + self.try(ltn12.pump.all(source, recvt.sink, step)) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + self.data:close() + self.data = nil + return 1 +end + +function metat.__index:cwd(dir) + self.try(self.tp:command("cwd", dir)) + self.try(self.tp:check(250)) + return 1 +end + +function metat.__index:type(type) + self.try(self.tp:command("type", type)) + self.try(self.tp:check(200)) + return 1 +end + +function metat.__index:greet() + local code = self.try(self.tp:check{"1..", "2.."}) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + return 1 +end + +function metat.__index:quit() + self.try(self.tp:command("quit")) + self.try(self.tp:check("2..")) + return 1 +end + +function metat.__index:close() + if self.data then self.data:close() end + if self.server then self.server:close() end + return self.tp:close() +end + +----------------------------------------------------------------------------- +-- High level FTP API +----------------------------------------------------------------------------- +local function override(t) + if t.url then + local u = url.parse(t.url) + for i,v in base.pairs(t) do + u[i] = v + end + return u + else return t end +end + +local function tput(putt) + putt = override(putt) + socket.try(putt.host, "missing hostname") + local f = _M.open(putt.host, putt.port, putt.create) + f:greet() + f:login(putt.user, putt.password) + if putt.type then f:type(putt.type) end + f:pasv() + local sent = f:send(putt) + f:quit() + f:close() + return sent +end + +local default = { + path = "/", + scheme = "ftp" +} + +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") + local pat = "^type=(.)$" + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") + end + return t +end + +local function sput(u, body) + local putt = parse(u) + putt.source = ltn12.source.string(body) + return tput(putt) +end + +_M.put = socket.protect(function(putt, body) + if base.type(putt) == "string" then return sput(putt, body) + else return tput(putt) end +end) + +local function tget(gett) + gett = override(gett) + socket.try(gett.host, "missing hostname") + local f = _M.open(gett.host, gett.port, gett.create) + f:greet() + f:login(gett.user, gett.password) + if gett.type then f:type(gett.type) end + f:pasv() + f:receive(gett) + f:quit() + return f:close() +end + +local function sget(u) + local gett = parse(u) + local t = {} + gett.sink = ltn12.sink.table(t) + tget(gett) + return table.concat(t) +end + +_M.command = socket.protect(function(cmdt) + cmdt = override(cmdt) + socket.try(cmdt.host, "missing hostname") + socket.try(cmdt.command, "missing command") + local f = open(cmdt.host, cmdt.port, cmdt.create) + f:greet() + f:login(cmdt.user, cmdt.password) + f.try(f.tp:command(cmdt.command, cmdt.argument)) + if cmdt.check then f.try(f.tp:check(cmdt.check)) end + f:quit() + return f:close() +end) + +_M.get = socket.protect(function(gett) + if base.type(gett) == "string" then return sget(gett) + else return tget(gett) end +end) + +return _M \ No newline at end of file diff --git a/socket/headers.lua b/socket/headers.lua new file mode 100644 index 0000000..1eb8223 --- /dev/null +++ b/socket/headers.lua @@ -0,0 +1,104 @@ +----------------------------------------------------------------------------- +-- Canonic header field capitalization +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- +local socket = require("socket") +socket.headers = {} +local _M = socket.headers + +_M.canonic = { + ["accept"] = "Accept", + ["accept-charset"] = "Accept-Charset", + ["accept-encoding"] = "Accept-Encoding", + ["accept-language"] = "Accept-Language", + ["accept-ranges"] = "Accept-Ranges", + ["action"] = "Action", + ["alternate-recipient"] = "Alternate-Recipient", + ["age"] = "Age", + ["allow"] = "Allow", + ["arrival-date"] = "Arrival-Date", + ["authorization"] = "Authorization", + ["bcc"] = "Bcc", + ["cache-control"] = "Cache-Control", + ["cc"] = "Cc", + ["comments"] = "Comments", + ["connection"] = "Connection", + ["content-description"] = "Content-Description", + ["content-disposition"] = "Content-Disposition", + ["content-encoding"] = "Content-Encoding", + ["content-id"] = "Content-ID", + ["content-language"] = "Content-Language", + ["content-length"] = "Content-Length", + ["content-location"] = "Content-Location", + ["content-md5"] = "Content-MD5", + ["content-range"] = "Content-Range", + ["content-transfer-encoding"] = "Content-Transfer-Encoding", + ["content-type"] = "Content-Type", + ["cookie"] = "Cookie", + ["date"] = "Date", + ["diagnostic-code"] = "Diagnostic-Code", + ["dsn-gateway"] = "DSN-Gateway", + ["etag"] = "ETag", + ["expect"] = "Expect", + ["expires"] = "Expires", + ["final-log-id"] = "Final-Log-ID", + ["final-recipient"] = "Final-Recipient", + ["from"] = "From", + ["host"] = "Host", + ["if-match"] = "If-Match", + ["if-modified-since"] = "If-Modified-Since", + ["if-none-match"] = "If-None-Match", + ["if-range"] = "If-Range", + ["if-unmodified-since"] = "If-Unmodified-Since", + ["in-reply-to"] = "In-Reply-To", + ["keywords"] = "Keywords", + ["last-attempt-date"] = "Last-Attempt-Date", + ["last-modified"] = "Last-Modified", + ["location"] = "Location", + ["max-forwards"] = "Max-Forwards", + ["message-id"] = "Message-ID", + ["mime-version"] = "MIME-Version", + ["original-envelope-id"] = "Original-Envelope-ID", + ["original-recipient"] = "Original-Recipient", + ["pragma"] = "Pragma", + ["proxy-authenticate"] = "Proxy-Authenticate", + ["proxy-authorization"] = "Proxy-Authorization", + ["range"] = "Range", + ["received"] = "Received", + ["received-from-mta"] = "Received-From-MTA", + ["references"] = "References", + ["referer"] = "Referer", + ["remote-mta"] = "Remote-MTA", + ["reply-to"] = "Reply-To", + ["reporting-mta"] = "Reporting-MTA", + ["resent-bcc"] = "Resent-Bcc", + ["resent-cc"] = "Resent-Cc", + ["resent-date"] = "Resent-Date", + ["resent-from"] = "Resent-From", + ["resent-message-id"] = "Resent-Message-ID", + ["resent-reply-to"] = "Resent-Reply-To", + ["resent-sender"] = "Resent-Sender", + ["resent-to"] = "Resent-To", + ["retry-after"] = "Retry-After", + ["return-path"] = "Return-Path", + ["sender"] = "Sender", + ["server"] = "Server", + ["smtp-remote-recipient"] = "SMTP-Remote-Recipient", + ["status"] = "Status", + ["subject"] = "Subject", + ["te"] = "TE", + ["to"] = "To", + ["trailer"] = "Trailer", + ["transfer-encoding"] = "Transfer-Encoding", + ["upgrade"] = "Upgrade", + ["user-agent"] = "User-Agent", + ["vary"] = "Vary", + ["via"] = "Via", + ["warning"] = "Warning", + ["will-retry-until"] = "Will-Retry-Until", + ["www-authenticate"] = "WWW-Authenticate", + ["x-mailer"] = "X-Mailer", +} + +return _M \ No newline at end of file diff --git a/socket/http.lua b/socket/http.lua new file mode 100644 index 0000000..ac4b2d6 --- /dev/null +++ b/socket/http.lua @@ -0,0 +1,354 @@ +----------------------------------------------------------------------------- +-- HTTP/1.1 client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +------------------------------------------------------------------------------- +local socket = require("socket") +local url = require("socket.url") +local ltn12 = require("ltn12") +local mime = require("mime") +local string = require("string") +local headers = require("socket.headers") +local base = _G +local table = require("table") +socket.http = {} +local _M = socket.http + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- connection timeout in seconds +TIMEOUT = 60 +-- default port for document retrieval +_M.PORT = 80 +-- user agent field sent in request +_M.USERAGENT = socket._VERSION + +----------------------------------------------------------------------------- +-- Reads MIME headers from a connection, unfolding where needed +----------------------------------------------------------------------------- +local function receiveheaders(sock, headers) + local line, name, value, err + headers = headers or {} + -- get first line + line, err = sock:receive() + if err then return nil, err end + -- headers go until a blank line is found + while line ~= "" do + -- get field-name and value + name, value = socket.skip(2, string.find(line, "^(.-):%s*(.*)")) + if not (name and value) then return nil, "malformed reponse headers" end + name = string.lower(name) + -- get next line (value might be folded) + line, err = sock:receive() + if err then return nil, err end + -- unfold any folded values + while string.find(line, "^%s") do + value = value .. line + line = sock:receive() + if err then return nil, err end + end + -- save pair in table + if headers[name] then headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + return headers +end + +----------------------------------------------------------------------------- +-- Extra sources and sinks +----------------------------------------------------------------------------- +socket.sourcet["http-chunked"] = function(sock, headers) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + -- get chunk size, skip extention + local line, err = sock:receive() + if err then return nil, err end + local size = base.tonumber(string.gsub(line, ";.*", ""), 16) + if not size then return nil, "invalid chunk size" end + -- was it the last chunk? + if size > 0 then + -- if not, get chunk and skip terminating CRLF + local chunk, err, part = sock:receive(size) + if chunk then sock:receive() end + return chunk, err + else + -- if it was, read trailers into headers table + headers, err = receiveheaders(sock, headers) + if not headers then return nil, err end + end + end + }) +end + +socket.sinkt["http-chunked"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then return sock:send("0\r\n\r\n") end + local size = string.format("%X\r\n", string.len(chunk)) + return sock:send(size .. chunk .. "\r\n") + end + }) +end + +----------------------------------------------------------------------------- +-- Low level HTTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(host, port, create) + -- create socket with user connect function, or with default + local c = socket.try((create or socket.tcp)()) + local h = base.setmetatable({ c = c }, metat) + -- create finalized try + h.try = socket.newtry(function() h:close() end) + -- set timeout before connecting + h.try(c:settimeout(_M.TIMEOUT)) + h.try(c:connect(host, port or _M.PORT)) + -- here everything worked + return h +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return self.try(self.c:send(reqline)) +end + +function metat.__index:sendheaders(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f, v in base.pairs(tosend) do + h = (canonic[f] or f) .. ": " .. v .. "\r\n" .. h + end + self.try(self.c:send(h)) + return 1 +end + +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode = "http-chunked" + if headers["content-length"] then mode = "keep-open" end + return self.try(ltn12.pump.all(source, socket.sink(mode, self.c), step)) +end + +function metat.__index:receivestatusline() + local status = self.try(self.c:receive(5)) + -- identify HTTP/0.9 responses, which do not contain a status line + -- this is just a heuristic, but is what the RFC recommends + if status ~= "HTTP/" then return nil, status end + -- otherwise proceed reading a status line + status = self.try(self.c:receive("*l", status)) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return self.try(base.tonumber(code), status) +end + +function metat.__index:receiveheaders() + return self.try(receiveheaders(self.c)) +end + +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = base.tonumber(headers["content-length"]) + local t = headers["transfer-encoding"] -- shortcut + local mode = "default" -- connection close + if t and t ~= "identity" then mode = "http-chunked" + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end + return self.try(ltn12.pump.all(socket.source(mode, self.c, length), + sink, step)) +end + +function metat.__index:receive09body(status, sink, step) + local source = ltn12.source.rewind(socket.source("until-closed", self.c)) + source(status) + return self.try(ltn12.pump.all(source, sink, step)) +end + +function metat.__index:close() + return self.c:close() +end + +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function adjusturi(reqt) + local u = reqt + -- if there is a proxy, we need the full url. otherwise, just a part. + if not reqt.proxy and not PROXY then + u = { + path = socket.try(reqt.path, "invalid path 'nil'"), + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment + } + end + return url.build(u) +end + +local function adjustproxy(reqt) + local proxy = reqt.proxy or PROXY + if proxy then + proxy = url.parse(proxy) + return proxy.host, proxy.port or 3128 + else + return reqt.host, reqt.port + end +end + +local function adjustheaders(reqt) + -- default headers + local lower = { + ["user-agent"] = _M.USERAGENT, + ["host"] = reqt.host, + ["connection"] = "close, TE", + ["te"] = "trailers" + } + -- if we have authentication information, pass it along + if reqt.user and reqt.password then + lower["authorization"] = + "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + end + -- override with user headers + for i,v in base.pairs(reqt.headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +-- default url parts +local default = { + host = "", + port = _M.PORT, + path ="/", + scheme = "http" +} + +local function adjustrequest(reqt) + -- parse url if provided + local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- explicit components override url + for i,v in base.pairs(reqt) do nreqt[i] = v end + if nreqt.port == "" then nreqt.port = 80 end + socket.try(nreqt.host and nreqt.host ~= "", + "invalid host '" .. base.tostring(nreqt.host) .. "'") + -- compute uri if user hasn't overriden + nreqt.uri = reqt.uri or adjusturi(nreqt) + -- ajust host and port if there is a proxy + nreqt.host, nreqt.port = adjustproxy(nreqt) + -- adjust headers in request + nreqt.headers = adjustheaders(nreqt) + return nreqt +end + +local function shouldredirect(reqt, code, headers) + return headers.location and + string.gsub(headers.location, "%s", "") ~= "" and + (reqt.redirect ~= false) and + (code == 301 or code == 302 or code == 303 or code == 307) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end + +local function shouldreceivebody(reqt, code) + if reqt.method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 +end + +-- forward declarations +local trequest, tredirect + +--[[local]] function tredirect(reqt, location) + local result, code, headers, status = trequest { + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = url.absolute(reqt.url, location), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1, + create = reqt.create + } + -- pass location header back as a hint we redirected + headers = headers or {} + headers.location = headers.location or location + return result, code, headers, status +end + +--[[local]] function trequest(reqt) + -- we loop until we get what we want, or + -- until we are sure there is no way to get it + local nreqt = adjustrequest(reqt) + local h = _M.open(nreqt.host, nreqt.port, nreqt.create) + -- send request line and headers + h:sendrequestline(nreqt.method, nreqt.uri) + h:sendheaders(nreqt.headers) + -- if there is a body, send it + if nreqt.source then + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + end + local code, status = h:receivestatusline() + -- if it is an HTTP/0.9 server, simply get the body and we are done + if not code then + h:receive09body(status, nreqt.sink, nreqt.step) + return 1, 200 + end + local headers + -- ignore any 100-continue messages + while code == 100 do + headers = h:receiveheaders() + code, status = h:receivestatusline() + end + headers = h:receiveheaders() + -- at this point we should have a honest reply from the server + -- we can't redirect if we already used the source, so we report the error + if shouldredirect(nreqt, code, headers) and not nreqt.source then + h:close() + return tredirect(reqt, headers.location) + end + -- here we are finally done + if shouldreceivebody(nreqt, code) then + h:receivebody(headers, nreqt.sink, nreqt.step) + end + h:close() + return 1, code, headers, status +end + +local function srequest(u, b) + local t = {} + local reqt = { + url = u, + sink = ltn12.sink.table(t) + } + if b then + reqt.source = ltn12.source.string(b) + reqt.headers = { + ["content-length"] = string.len(b), + ["content-type"] = "application/x-www-form-urlencoded" + } + reqt.method = "POST" + end + local code, headers, status = socket.skip(1, trequest(reqt)) + return table.concat(t), code, headers, status +end + +_M.request = socket.protect(function(reqt, body) + if base.type(reqt) == "string" then return srequest(reqt, body) + else return trequest(reqt) end +end) + +return _M \ No newline at end of file diff --git a/socket/smtp.lua b/socket/smtp.lua new file mode 100644 index 0000000..b113d00 --- /dev/null +++ b/socket/smtp.lua @@ -0,0 +1,256 @@ +----------------------------------------------------------------------------- +-- SMTP client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local coroutine = require("coroutine") +local string = require("string") +local math = require("math") +local os = require("os") +local socket = require("socket") +local tp = require("socket.tp") +local ltn12 = require("ltn12") +local headers = require("socket.headers") +local mime = require("mime") + +socket.smtp = {} +local _M = socket.smtp + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- timeout for connection +_M.TIMEOUT = 60 +-- default server used to send e-mails +_M.SERVER = "localhost" +-- default port +_M.PORT = 25 +-- domain used in HELO command and default sendmail +-- If we are under a CGI, try to get from environment +_M.DOMAIN = os.getenv("SERVER_NAME") or "localhost" +-- default time zone (means we don't know) +_M.ZONE = "-0000" + +--------------------------------------------------------------------------- +-- Low level SMTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function metat.__index:greet(domain) + self.try(self.tp:check("2..")) + self.try(self.tp:command("EHLO", domain or _M.DOMAIN)) + return socket.skip(1, self.try(self.tp:check("2.."))) +end + +function metat.__index:mail(from) + self.try(self.tp:command("MAIL", "FROM:" .. from)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:rcpt(to) + self.try(self.tp:command("RCPT", "TO:" .. to)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:data(src, step) + self.try(self.tp:command("DATA")) + self.try(self.tp:check("3..")) + self.try(self.tp:source(src, step)) + self.try(self.tp:send("\r\n.\r\n")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:quit() + self.try(self.tp:command("QUIT")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:close() + return self.tp:close() +end + +function metat.__index:login(user, password) + self.try(self.tp:command("AUTH", "LOGIN")) + self.try(self.tp:check("3..")) + self.try(self.tp:send(mime.b64(user) .. "\r\n")) + self.try(self.tp:check("3..")) + self.try(self.tp:send(mime.b64(password) .. "\r\n")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:plain(user, password) + local auth = "PLAIN " .. mime.b64("\0" .. user .. "\0" .. password) + self.try(self.tp:command("AUTH", auth)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:auth(user, password, ext) + if not user or not password then return 1 end + if string.find(ext, "AUTH[^\n]+LOGIN") then + return self:login(user, password) + elseif string.find(ext, "AUTH[^\n]+PLAIN") then + return self:plain(user, password) + else + self.try(nil, "authentication not supported") + end +end + +-- send message or throw an exception +function metat.__index:send(mailt) + self:mail(mailt.from) + if base.type(mailt.rcpt) == "table" then + for i,v in base.ipairs(mailt.rcpt) do + self:rcpt(v) + end + else + self:rcpt(mailt.rcpt) + end + self:data(ltn12.source.chain(mailt.source, mime.stuff()), mailt.step) +end + +function _M.open(server, port, create) + local tp = socket.try(tp.connect(server or _M.SERVER, port or _M.PORT, + _M.TIMEOUT, create)) + local s = base.setmetatable({tp = tp}, metat) + -- make sure tp is closed if we get an exception + s.try = socket.newtry(function() + s:close() + end) + return s +end + +-- convert headers to lowercase +local function lower_headers(headers) + local lower = {} + for i,v in base.pairs(headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +--------------------------------------------------------------------------- +-- Multipart message source +----------------------------------------------------------------------------- +-- returns a hopefully unique mime boundary +local seqno = 0 +local function newboundary() + seqno = seqno + 1 + return string.format('%s%05d==%05u', os.date('%d%m%Y%H%M%S'), + math.random(0, 99999), seqno) +end + +-- send_message forward declaration +local send_message + +-- yield the headers all at once, it's faster +local function send_headers(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f,v in base.pairs(tosend) do + h = (canonic[f] or f) .. ': ' .. v .. "\r\n" .. h + end + coroutine.yield(h) +end + +-- yield multipart message body from a multipart message table +local function send_multipart(mesgt) + -- make sure we have our boundary and send headers + local bd = newboundary() + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or 'multipart/mixed' + headers['content-type'] = headers['content-type'] .. + '; boundary="' .. bd .. '"' + send_headers(headers) + -- send preamble + if mesgt.body.preamble then + coroutine.yield(mesgt.body.preamble) + coroutine.yield("\r\n") + end + -- send each part separated by a boundary + for i, m in base.ipairs(mesgt.body) do + coroutine.yield("\r\n--" .. bd .. "\r\n") + send_message(m) + end + -- send last boundary + coroutine.yield("\r\n--" .. bd .. "--\r\n\r\n") + -- send epilogue + if mesgt.body.epilogue then + coroutine.yield(mesgt.body.epilogue) + coroutine.yield("\r\n") + end +end + +-- yield message body from a source +local function send_source(mesgt) + -- make sure we have a content-type + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or + 'text/plain; charset="iso-8859-1"' + send_headers(headers) + -- send body from source + while true do + local chunk, err = mesgt.body() + if err then coroutine.yield(nil, err) + elseif chunk then coroutine.yield(chunk) + else break end + end +end + +-- yield message body from a string +local function send_string(mesgt) + -- make sure we have a content-type + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or + 'text/plain; charset="iso-8859-1"' + send_headers(headers) + -- send body from string + coroutine.yield(mesgt.body) +end + +-- message source +function send_message(mesgt) + if base.type(mesgt.body) == "table" then send_multipart(mesgt) + elseif base.type(mesgt.body) == "function" then send_source(mesgt) + else send_string(mesgt) end +end + +-- set defaul headers +local function adjust_headers(mesgt) + local lower = lower_headers(mesgt.headers) + lower["date"] = lower["date"] or + os.date("!%a, %d %b %Y %H:%M:%S ") .. (mesgt.zone or _M.ZONE) + lower["x-mailer"] = lower["x-mailer"] or socket._VERSION + -- this can't be overriden + lower["mime-version"] = "1.0" + return lower +end + +function _M.message(mesgt) + mesgt.headers = adjust_headers(mesgt) + -- create and return message source + local co = coroutine.create(function() send_message(mesgt) end) + return function() + local ret, a, b = coroutine.resume(co) + if ret then return a, b + else return nil, a end + end +end + +--------------------------------------------------------------------------- +-- High level SMTP API +----------------------------------------------------------------------------- +_M.send = socket.protect(function(mailt) + local s = _M.open(mailt.server, mailt.port, mailt.create) + local ext = s:greet(mailt.domain) + s:auth(mailt.user, mailt.password, ext) + s:send(mailt) + s:quit() + return s:close() +end) + +return _M \ No newline at end of file diff --git a/socket/tp.lua b/socket/tp.lua new file mode 100644 index 0000000..cbeff56 --- /dev/null +++ b/socket/tp.lua @@ -0,0 +1,126 @@ +----------------------------------------------------------------------------- +-- Unified SMTP/FTP subsystem +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local string = require("string") +local socket = require("socket") +local ltn12 = require("ltn12") + +socket.tp = {} +local _M = socket.tp + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +_M.TIMEOUT = 60 + +----------------------------------------------------------------------------- +-- Implementation +----------------------------------------------------------------------------- +-- gets server reply (works for SMTP and FTP) +local function get_reply(c) + local code, current, sep + local line, err = c:receive() + local reply = line + if err then return nil, err end + code, sep = socket.skip(2, string.find(line, "^(%d%d%d)(.?)")) + if not code then return nil, "invalid server reply" end + if sep == "-" then -- reply is multiline + repeat + line, err = c:receive() + if err then return nil, err end + current, sep = socket.skip(2, string.find(line, "^(%d%d%d)(.?)")) + reply = reply .. "\n" .. line + -- reply ends with same code + until code == current and sep == " " + end + return code, reply +end + +-- metatable for sock object +local metat = { __index = {} } + +function metat.__index:check(ok) + local code, reply = get_reply(self.c) + if not code then return nil, reply end + if base.type(ok) ~= "function" then + if base.type(ok) == "table" then + for i, v in base.ipairs(ok) do + if string.find(code, v) then + return base.tonumber(code), reply + end + end + return nil, reply + else + if string.find(code, ok) then return base.tonumber(code), reply + else return nil, reply end + end + else return ok(base.tonumber(code), reply) end +end + +function metat.__index:command(cmd, arg) + cmd = string.upper(cmd) + if arg then + return self.c:send(cmd .. " " .. arg.. "\r\n") + else + return self.c:send(cmd .. "\r\n") + end +end + +function metat.__index:sink(snk, pat) + local chunk, err = c:receive(pat) + return snk(chunk, err) +end + +function metat.__index:send(data) + return self.c:send(data) +end + +function metat.__index:receive(pat) + return self.c:receive(pat) +end + +function metat.__index:getfd() + return self.c:getfd() +end + +function metat.__index:dirty() + return self.c:dirty() +end + +function metat.__index:getcontrol() + return self.c +end + +function metat.__index:source(source, step) + local sink = socket.sink("keep-open", self.c) + local ret, err = ltn12.pump.all(source, sink, step or ltn12.pump.step) + return ret, err +end + +-- closes the underlying c +function metat.__index:close() + self.c:close() + return 1 +end + +-- connect with server and return c object +function _M.connect(host, port, timeout, create) + local c, e = (create or socket.tcp)() + if not c then return nil, e end + c:settimeout(timeout or _M.TIMEOUT) + local r, e = c:connect(host, port) + if not r then + c:close() + return nil, e + end + return base.setmetatable({c = c}, metat) +end + +return _M \ No newline at end of file diff --git a/socket/url.lua b/socket/url.lua new file mode 100644 index 0000000..7809535 --- /dev/null +++ b/socket/url.lua @@ -0,0 +1,307 @@ +----------------------------------------------------------------------------- +-- URI parsing, composition and relative URL resolution +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module +----------------------------------------------------------------------------- +local string = require("string") +local base = _G +local table = require("table") +local socket = require("socket") + +socket.url = {} +local _M = socket.url + +----------------------------------------------------------------------------- +-- Module version +----------------------------------------------------------------------------- +_M._VERSION = "URL 1.0.3" + +----------------------------------------------------------------------------- +-- Encodes a string into its escaped hexadecimal representation +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +function _M.escape(s) + return (string.gsub(s, "([^A-Za-z0-9_])", function(c) + return string.format("%%%02x", string.byte(c)) + end)) +end + +----------------------------------------------------------------------------- +-- Protects a path segment, to prevent it from interfering with the +-- url parsing. +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +local function make_set(t) + local s = {} + for i,v in base.ipairs(t) do + s[t[i]] = 1 + end + return s +end + +-- these are allowed withing a path segment, along with alphanum +-- other characters must be escaped +local segment_set = make_set { + "-", "_", ".", "!", "~", "*", "'", "(", + ")", ":", "@", "&", "=", "+", "$", ",", +} + +local function protect_segment(s) + return string.gsub(s, "([^A-Za-z0-9_])", function (c) + if segment_set[c] then return c + else return string.format("%%%02x", string.byte(c)) end + end) +end + +----------------------------------------------------------------------------- +-- Encodes a string into its escaped hexadecimal representation +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +function _M.unescape(s) + return (string.gsub(s, "%%(%x%x)", function(hex) + return string.char(base.tonumber(hex, 16)) + end)) +end + +----------------------------------------------------------------------------- +-- Builds a path from a base path and a relative path +-- Input +-- base_path +-- relative_path +-- Returns +-- corresponding absolute path +----------------------------------------------------------------------------- +local function absolute_path(base_path, relative_path) + if string.sub(relative_path, 1, 1) == "/" then return relative_path end + local path = string.gsub(base_path, "[^/]*$", "") + path = path .. relative_path + path = string.gsub(path, "([^/]*%./)", function (s) + if s ~= "./" then return s else return "" end + end) + path = string.gsub(path, "/%.$", "/") + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, "([^/]*/%.%./)", function (s) + if s ~= "../../" then return "" else return s end + end) + end + path = string.gsub(reduced, "([^/]*/%.%.)$", function (s) + if s ~= "../.." then return "" else return s end + end) + return path +end + +----------------------------------------------------------------------------- +-- Parses a url and returns a table with all its parts according to RFC 2396 +-- The following grammar describes the names given to the URL parts +-- ::= :///;?# +-- ::= @: +-- ::= [:] +-- :: = {/} +-- Input +-- url: uniform resource locator of request +-- default: table with default values for each field +-- Returns +-- table with the following fields, where RFC naming conventions have +-- been preserved: +-- scheme, authority, userinfo, user, password, host, port, +-- path, params, query, fragment +-- Obs: +-- the leading '/' in {/} is considered part of +----------------------------------------------------------------------------- +function _M.parse(url, default) + -- initialize default parameters + local parsed = {} + for i,v in base.pairs(default or parsed) do parsed[i] = v end + -- empty url is parsed to nil + if not url or url == "" then return nil, "invalid url" end + -- remove whitespace + -- url = string.gsub(url, "%s", "") + -- get fragment + url = string.gsub(url, "#(.*)$", function(f) + parsed.fragment = f + return "" + end) + -- get scheme + url = string.gsub(url, "^([%w][%w%+%-%.]*)%:", + function(s) parsed.scheme = s; return "" end) + -- get authority + url = string.gsub(url, "^//([^/]*)", function(n) + parsed.authority = n + return "" + end) + -- get query string + url = string.gsub(url, "%?(.*)", function(q) + parsed.query = q + return "" + end) + -- get params + url = string.gsub(url, "%;(.*)", function(p) + parsed.params = p + return "" + end) + -- path is whatever was left + if url ~= "" then parsed.path = url end + local authority = parsed.authority + if not authority then return parsed end + authority = string.gsub(authority,"^([^@]*)@", + function(u) parsed.userinfo = u; return "" end) + authority = string.gsub(authority, ":([^:%]]*)$", + function(p) parsed.port = p; return "" end) + if authority ~= "" then + -- IPv6? + parsed.host = string.match(authority, "^%[(.+)%]$") or authority + end + local userinfo = parsed.userinfo + if not userinfo then return parsed end + userinfo = string.gsub(userinfo, ":([^:]*)$", + function(p) parsed.password = p; return "" end) + parsed.user = userinfo + return parsed +end + +----------------------------------------------------------------------------- +-- Rebuilds a parsed URL from its components. +-- Components are protected if any reserved or unallowed characters are found +-- Input +-- parsed: parsed URL, as returned by parse +-- Returns +-- a stringing with the corresponding URL +----------------------------------------------------------------------------- +function _M.build(parsed) + local ppath = _M.parse_path(parsed.path or "") + local url = _M.build_path(ppath) + if parsed.params then url = url .. ";" .. parsed.params end + if parsed.query then url = url .. "?" .. parsed.query end + local authority = parsed.authority + if parsed.host then + authority = parsed.host + if string.find(authority, ":") then -- IPv6? + authority = "[" .. authority .. "]" + end + if parsed.port then authority = authority .. ":" .. parsed.port end + local userinfo = parsed.userinfo + if parsed.user then + userinfo = parsed.user + if parsed.password then + userinfo = userinfo .. ":" .. parsed.password + end + end + if userinfo then authority = userinfo .. "@" .. authority end + end + if authority then url = "//" .. authority .. url end + if parsed.scheme then url = parsed.scheme .. ":" .. url end + if parsed.fragment then url = url .. "#" .. parsed.fragment end + -- url = string.gsub(url, "%s", "") + return url +end + +----------------------------------------------------------------------------- +-- Builds a absolute URL from a base and a relative URL according to RFC 2396 +-- Input +-- base_url +-- relative_url +-- Returns +-- corresponding absolute url +----------------------------------------------------------------------------- +function _M.absolute(base_url, relative_url) + if base.type(base_url) == "table" then + base_parsed = base_url + base_url = _M.build(base_parsed) + else + base_parsed = _M.parse(base_url) + end + local relative_parsed = _M.parse(relative_url) + if not base_parsed then return relative_url + elseif not relative_parsed then return base_url + elseif relative_parsed.scheme then return relative_url + else + relative_parsed.scheme = base_parsed.scheme + if not relative_parsed.authority then + relative_parsed.authority = base_parsed.authority + if not relative_parsed.path then + relative_parsed.path = base_parsed.path + if not relative_parsed.params then + relative_parsed.params = base_parsed.params + if not relative_parsed.query then + relative_parsed.query = base_parsed.query + end + end + else + relative_parsed.path = absolute_path(base_parsed.path or "", + relative_parsed.path) + end + end + return _M.build(relative_parsed) + end +end + +----------------------------------------------------------------------------- +-- Breaks a path into its segments, unescaping the segments +-- Input +-- path +-- Returns +-- segment: a table with one entry per segment +----------------------------------------------------------------------------- +function _M.parse_path(path) + local parsed = {} + path = path or "" + --path = string.gsub(path, "%s", "") + string.gsub(path, "([^/]+)", function (s) table.insert(parsed, s) end) + for i = 1, #parsed do + parsed[i] = _M.unescape(parsed[i]) + end + if string.sub(path, 1, 1) == "/" then parsed.is_absolute = 1 end + if string.sub(path, -1, -1) == "/" then parsed.is_directory = 1 end + return parsed +end + +----------------------------------------------------------------------------- +-- Builds a path component from its segments, escaping protected characters. +-- Input +-- parsed: path segments +-- unsafe: if true, segments are not protected before path is built +-- Returns +-- path: corresponding path stringing +----------------------------------------------------------------------------- +function _M.build_path(parsed, unsafe) + local path = "" + local n = #parsed + if unsafe then + for i = 1, n-1 do + path = path .. parsed[i] + path = path .. "/" + end + if n > 0 then + path = path .. parsed[n] + if parsed.is_directory then path = path .. "/" end + end + else + for i = 1, n-1 do + path = path .. protect_segment(parsed[i]) + path = path .. "/" + end + if n > 0 then + path = path .. protect_segment(parsed[n]) + if parsed.is_directory then path = path .. "/" end + end + end + if parsed.is_absolute then path = "/" .. path end + return path +end + +return _M diff --git a/ssl.dll b/ssl.dll new file mode 100644 index 0000000..444c9c3 Binary files /dev/null and b/ssl.dll differ diff --git a/ssl.lua b/ssl.lua new file mode 100644 index 0000000..3bd236b --- /dev/null +++ b/ssl.lua @@ -0,0 +1,189 @@ +------------------------------------------------------------------------------ +-- LuaSec 0.7 +-- +-- Copyright (C) 2006-2018 Bruno Silvestre +-- +------------------------------------------------------------------------------ + +local core = require("ssl.core") +local context = require("ssl.context") +local x509 = require("ssl.x509") +local config = require("ssl.config") + +local unpack = table.unpack or unpack + +-- We must prevent the contexts to be collected before the connections, +-- otherwise the C registry will be cleared. +local registry = setmetatable({}, {__mode="k"}) + +-- +-- +-- +local function optexec(func, param, ctx) + if param then + if type(param) == "table" then + return func(ctx, unpack(param)) + else + return func(ctx, param) + end + end + return true +end + +-- +-- +-- +local function newcontext(cfg) + local succ, msg, ctx + -- Create the context + ctx, msg = context.create(cfg.protocol) + if not ctx then return nil, msg end + -- Mode + succ, msg = context.setmode(ctx, cfg.mode) + if not succ then return nil, msg end + -- Load the key + if cfg.key then + if cfg.password and + type(cfg.password) ~= "function" and + type(cfg.password) ~= "string" + then + return nil, "invalid password type" + end + succ, msg = context.loadkey(ctx, cfg.key, cfg.password) + if not succ then return nil, msg end + end + -- Load the certificate + if cfg.certificate then + succ, msg = context.loadcert(ctx, cfg.certificate) + if not succ then return nil, msg end + if cfg.key and context.checkkey then + succ = context.checkkey(ctx) + if not succ then return nil, "private key does not match public key" end + end + end + -- Load the CA certificates + if cfg.cafile or cfg.capath then + succ, msg = context.locations(ctx, cfg.cafile, cfg.capath) + if not succ then return nil, msg end + end + -- Set SSL ciphers + if cfg.ciphers then + succ, msg = context.setcipher(ctx, cfg.ciphers) + if not succ then return nil, msg end + end + -- Set the verification options + succ, msg = optexec(context.setverify, cfg.verify, ctx) + if not succ then return nil, msg end + -- Set SSL options + succ, msg = optexec(context.setoptions, cfg.options, ctx) + if not succ then return nil, msg end + -- Set the depth for certificate verification + if cfg.depth then + succ, msg = context.setdepth(ctx, cfg.depth) + if not succ then return nil, msg end + end + + -- NOTE: Setting DH parameters and elliptic curves needs to come after + -- setoptions(), in case the user has specified the single_{dh,ecdh}_use + -- options. + + -- Set DH parameters + if cfg.dhparam then + if type(cfg.dhparam) ~= "function" then + return nil, "invalid DH parameter type" + end + context.setdhparam(ctx, cfg.dhparam) + end + + -- Set elliptic curves + if (not config.algorithms.ec) and (cfg.curve or cfg.curveslist) then + return false, "elliptic curves not supported" + end + if config.capabilities.curves_list and cfg.curveslist then + succ, msg = context.setcurveslist(ctx, cfg.curveslist) + if not succ then return nil, msg end + elseif cfg.curve then + succ, msg = context.setcurve(ctx, cfg.curve) + if not succ then return nil, msg end + end + + -- Set extra verification options + if cfg.verifyext and ctx.setverifyext then + succ, msg = optexec(ctx.setverifyext, cfg.verifyext, ctx) + if not succ then return nil, msg end + end + + return ctx +end + +-- +-- +-- +local function wrap(sock, cfg) + local ctx, msg + if type(cfg) == "table" then + ctx, msg = newcontext(cfg) + if not ctx then return nil, msg end + else + ctx = cfg + end + local s, msg = core.create(ctx) + if s then + core.setfd(s, sock:getfd()) + sock:setfd(core.SOCKET_INVALID) + registry[s] = ctx + return s + end + return nil, msg +end + +-- +-- Extract connection information. +-- +local function info(ssl, field) + local str, comp, err, protocol + comp, err = core.compression(ssl) + if err then + return comp, err + end + -- Avoid parser + if field == "compression" then + return comp + end + local info = {compression = comp} + str, info.bits, info.algbits, protocol = core.info(ssl) + if str then + info.cipher, info.protocol, info.key, + info.authentication, info.encryption, info.mac = + string.match(str, + "^(%S+)%s+(%S+)%s+Kx=(%S+)%s+Au=(%S+)%s+Enc=(%S+)%s+Mac=(%S+)") + info.export = (string.match(str, "%sexport%s*$") ~= nil) + end + if protocol then + info.protocol = protocol + end + if field then + return info[field] + end + -- Empty? + return ( (next(info)) and info ) +end + +-- +-- Set method for SSL connections. +-- +core.setmethod("info", info) + +-------------------------------------------------------------------------------- +-- Export module +-- + +local _M = { + _VERSION = "0.7", + _COPYRIGHT = core.copyright(), + loadcertificate = x509.load, + newcontext = newcontext, + wrap = wrap, +} + +return _M diff --git a/ssl/https.lua b/ssl/https.lua new file mode 100644 index 0000000..d1b708a --- /dev/null +++ b/ssl/https.lua @@ -0,0 +1,143 @@ +---------------------------------------------------------------------------- +-- LuaSec 0.7 +-- Copyright (C) 2009-2018 PUC-Rio +-- +-- Author: Pablo Musa +-- Author: Tomas Guisasola +--------------------------------------------------------------------------- + +local socket = require("socket") +local ssl = require("ssl") +local ltn12 = require("ltn12") +local http = require("socket.http") +local url = require("socket.url") + +local try = socket.try + +-- +-- Module +-- +local _M = { + _VERSION = "0.7", + _COPYRIGHT = "LuaSec 0.7 - Copyright (C) 2009-2018 PUC-Rio", + PORT = 443, +} + +-- TLS configuration +local cfg = { + protocol = "any", + options = {"all", "no_sslv2", "no_sslv3"}, + verify = "none", +} + +-------------------------------------------------------------------- +-- Auxiliar Functions +-------------------------------------------------------------------- + +-- Insert default HTTPS port. +local function default_https_port(u) + return url.build(url.parse(u, {port = _M.PORT})) +end + +-- Convert an URL to a table according to Luasocket needs. +local function urlstring_totable(url, body, result_table) + url = { + url = default_https_port(url), + method = body and "POST" or "GET", + sink = ltn12.sink.table(result_table) + } + if body then + url.source = ltn12.source.string(body) + url.headers = { + ["content-length"] = #body, + ["content-type"] = "application/x-www-form-urlencoded", + } + end + return url +end + +-- Forward calls to the real connection object. +local function reg(conn) + local mt = getmetatable(conn.sock).__index + for name, method in pairs(mt) do + if type(method) == "function" then + conn[name] = function (self, ...) + return method(self.sock, ...) + end + end + end +end + +-- Return a function which performs the SSL/TLS connection. +local function tcp(params) + params = params or {} + -- Default settings + for k, v in pairs(cfg) do + params[k] = params[k] or v + end + -- Force client mode + params.mode = "client" + -- 'create' function for LuaSocket + return function () + local conn = {} + conn.sock = try(socket.tcp()) + local st = getmetatable(conn.sock).__index.settimeout + function conn:settimeout(...) + return st(self.sock, ...) + end + -- Replace TCP's connection function + function conn:connect(host, port) + try(self.sock:connect(host, port)) + self.sock = try(ssl.wrap(self.sock, params)) + self.sock:sni(host) + try(self.sock:dohandshake()) + reg(self, getmetatable(self.sock)) + return 1 + end + return conn + end +end + +-------------------------------------------------------------------- +-- Main Function +-------------------------------------------------------------------- + +-- Make a HTTP request over secure connection. This function receives +-- the same parameters of LuaSocket's HTTP module (except 'proxy' and +-- 'redirect') plus LuaSec parameters. +-- +-- @param url mandatory (string or table) +-- @param body optional (string) +-- @return (string if url == string or 1), code, headers, status +-- +local function request(url, body) + local result_table = {} + local stringrequest = type(url) == "string" + if stringrequest then + url = urlstring_totable(url, body, result_table) + else + url.url = default_https_port(url.url) + end + if http.PROXY or url.proxy then + return nil, "proxy not supported" + elseif url.redirect then + return nil, "redirect not supported" + elseif url.create then + return nil, "create function not permitted" + end + -- New 'create' function to establish a secure connection + url.create = tcp(url) + local res, code, headers, status = http.request(url) + if res and stringrequest then + return table.concat(result_table), code, headers, status + end + return res, code, headers, status +end + +-------------------------------------------------------------------------------- +-- Export module +-- + +_M.request = request + +return _M diff --git a/websocket.lua b/websocket.lua new file mode 100644 index 0000000..4979057 --- /dev/null +++ b/websocket.lua @@ -0,0 +1,12 @@ +local frame = require'websocket.frame' + +return { + client = require'websocket.client', + server = require'websocket.server', + CONTINUATION = frame.CONTINUATION, + TEXT = frame.TEXT, + BINARY = frame.BINARY, + CLOSE = frame.CLOSE, + PING = frame.PING, + PONG = frame.PONG +} diff --git a/websocket/bit.lua b/websocket/bit.lua new file mode 100644 index 0000000..f8fc685 --- /dev/null +++ b/websocket/bit.lua @@ -0,0 +1,10 @@ +local has_bit32,bit = pcall(require,'bit32') +if has_bit32 then + -- lua 5.2 / bit32 library + bit.rol = bit.lrotate + bit.ror = bit.rrotate + return bit +else + -- luajit / lua 5.1 + luabitop + return require'bit' +end diff --git a/websocket/client.lua b/websocket/client.lua new file mode 100644 index 0000000..b31e90a --- /dev/null +++ b/websocket/client.lua @@ -0,0 +1,7 @@ +return setmetatable({},{__index = function(self, name) + if name == 'new' then name = 'sync' end + local backend = require("websocket.client_" .. name) + self[name] = backend + if name == 'sync' then self.new = backend end + return backend +end}) diff --git a/websocket/client_copas.lua b/websocket/client_copas.lua new file mode 100644 index 0000000..f404706 --- /dev/null +++ b/websocket/client_copas.lua @@ -0,0 +1,40 @@ +local socket = require'socket' +local sync = require'websocket.sync' +local tools = require'websocket.tools' + +local new = function(ws) + ws = ws or {} + local copas = require'copas' + + local self = {} + + self.sock_connect = function(self,host,port) + self.sock = socket.tcp() + if ws.timeout ~= nil then + self.sock:settimeout(ws.timeout) + end + local _,err = copas.connect(self.sock,host,port) + if err and err ~= 'already connected' then + self.sock:close() + return nil,err + end + end + + self.sock_send = function(self,...) + return copas.send(self.sock,...) + end + + self.sock_receive = function(self,...) + return copas.receive(self.sock,...) + end + + self.sock_close = function(self) + self.sock:shutdown() + self.sock:close() + end + + self = sync.extend(self) + return self +end + +return new diff --git a/websocket/client_ev.lua b/websocket/client_ev.lua new file mode 100644 index 0000000..d9564c2 --- /dev/null +++ b/websocket/client_ev.lua @@ -0,0 +1,248 @@ + +local socket = require'socket' +local tools = require'websocket.tools' +local frame = require'websocket.frame' +local handshake = require'websocket.handshake' +local debug = require'debug' +local tconcat = table.concat +local tinsert = table.insert + +local ev = function(ws) + ws = ws or {} + local ev = require'ev' + local sock + local loop = ws.loop or ev.Loop.default + local fd + local message_io + local handshake_io + local send_io_stop + local async_send + local self = {} + self.state = 'CLOSED' + local close_timer + local user_on_message + local user_on_close + local user_on_open + local user_on_error + local cleanup = function() + if close_timer then + close_timer:stop(loop) + close_timer = nil + end + if handshake_io then + handshake_io:stop(loop) + handshake_io:clear_pending(loop) + handshake_io = nil + end + if send_io_stop then + send_io_stop() + send_io_stop = nil + end + if message_io then + message_io:stop(loop) + message_io:clear_pending(loop) + message_io = nil + end + if sock then + sock:shutdown() + sock:close() + sock = nil + end + end + + local on_close = function(was_clean,code,reason) + cleanup() + self.state = 'CLOSED' + if user_on_close then + user_on_close(self,was_clean,code,reason or '') + end + end + local on_error = function(err,dont_cleanup) + if not dont_cleanup then + cleanup() + end + if user_on_error then + user_on_error(self,err) + else + print('Error',err) + end + end + local on_open = function(_,headers) + self.state = 'OPEN' + if user_on_open then + user_on_open(self,headers['sec-websocket-protocol'],headers) + end + end + local handle_socket_err = function(err,io,sock) + if self.state == 'OPEN' then + on_close(false,1006,err) + elseif self.state ~= 'CLOSED' then + on_error(err) + end + end + local on_message = function(message,opcode) + if opcode == frame.TEXT or opcode == frame.BINARY then + if user_on_message then + user_on_message(self,message,opcode) + end + elseif opcode == frame.CLOSE then + if self.state ~= 'CLOSING' then + self.state = 'CLOSING' + local code,reason = frame.decode_close(message) + local encoded = frame.encode_close(code) + encoded = frame.encode(encoded,frame.CLOSE,true) + async_send(encoded, + function() + on_close(true,code or 1005,reason) + end,handle_socket_err) + else + on_close(true,1005,'') + end + end + end + + self.send = function(_,message,opcode) + local encoded = frame.encode(message,opcode or frame.TEXT,true) + async_send(encoded, nil, handle_socket_err) + end + + self.connect = function(_,url,ws_protocol) + if self.state ~= 'CLOSED' then + on_error('wrong state',true) + return + end + local protocol,host,port,uri = tools.parse_url(url) + if protocol ~= 'ws' then + on_error('bad protocol') + return + end + local ws_protocols_tbl = {''} + if type(ws_protocol) == 'string' then + ws_protocols_tbl = {ws_protocol} + elseif type(ws_protocol) == 'table' then + ws_protocols_tbl = ws_protocol + end + self.state = 'CONNECTING' + assert(not sock) + sock = socket.tcp() + fd = sock:getfd() + assert(fd > -1) + -- set non blocking + sock:settimeout(0) + sock:setoption('tcp-nodelay',true) + async_send,send_io_stop = require'websocket.ev_common'.async_send(sock,loop) + handshake_io = ev.IO.new( + function(loop,connect_io) + connect_io:stop(loop) + local key = tools.generate_key() + local req = handshake.upgrade_request + { + key = key, + host = host, + port = port, + protocols = ws_protocols_tbl, + origin = ws.origin, + uri = uri + } + async_send( + req, + function() + local resp = {} + local response = '' + local read_upgrade = function(loop,read_io) + -- this seems to be possible, i don't understand why though :( + if not sock then + read_io:stop(loop) + handshake_io = nil + return + end + repeat + local byte,err,pp = sock:receive(1) + if byte then + response = response..byte + elseif err then + if err == 'timeout' then + return + else + read_io:stop(loop) + on_error('accept failed') + return + end + end + until response:sub(#response-3) == '\r\n\r\n' + read_io:stop(loop) + handshake_io = nil + local headers = handshake.http_headers(response) + local expected_accept = handshake.sec_websocket_accept(key) + if headers['sec-websocket-accept'] ~= expected_accept then + self.state = 'CLOSED' + on_error('accept failed') + return + end + message_io = require'websocket.ev_common'.message_io( + sock,loop, + on_message, + handle_socket_err) + on_open(self, headers) + end + handshake_io = ev.IO.new(read_upgrade,fd,ev.READ) + handshake_io:start(loop)-- handshake + end, + handle_socket_err) + end,fd,ev.WRITE) + local connected,err = sock:connect(host,port) + if connected then + handshake_io:callback()(loop,handshake_io) + elseif err == 'timeout' or err == 'Operation already in progress' then + handshake_io:start(loop)-- connect + else + self.state = 'CLOSED' + on_error(err) + end + end + + self.on_close = function(_,on_close_arg) + user_on_close = on_close_arg + end + + self.on_error = function(_,on_error_arg) + user_on_error = on_error_arg + end + + self.on_open = function(_,on_open_arg) + user_on_open = on_open_arg + end + + self.on_message = function(_,on_message_arg) + user_on_message = on_message_arg + end + + self.close = function(_,code,reason,timeout) + if handshake_io then + handshake_io:stop(loop) + handshake_io:clear_pending(loop) + end + if self.state == 'CONNECTING' then + self.state = 'CLOSING' + on_close(false,1006,'') + return + elseif self.state == 'OPEN' then + self.state = 'CLOSING' + timeout = timeout or 3 + local encoded = frame.encode_close(code or 1000,reason) + encoded = frame.encode(encoded,frame.CLOSE,true) + -- this should let the other peer confirm the CLOSE message + -- by 'echoing' the message. + async_send(encoded) + close_timer = ev.Timer.new(function() + close_timer = nil + on_close(false,1006,'timeout') + end,timeout) + close_timer:start(loop) + end + end + + return self +end + +return ev diff --git a/websocket/client_sync.lua b/websocket/client_sync.lua new file mode 100644 index 0000000..eaac5d7 --- /dev/null +++ b/websocket/client_sync.lua @@ -0,0 +1,38 @@ +local socket = require'socket' +local sync = require'websocket.sync' +local tools = require'websocket.tools' + +local new = function(ws) + ws = ws or {} + local self = {} + + self.sock_connect = function(self,host,port) + self.sock = socket.tcp() + if ws.timeout ~= nil then + self.sock:settimeout(ws.timeout) + end + local _,err = self.sock:connect(host,port) + if err then + self.sock:close() + return nil,err + end + end + + self.sock_send = function(self,...) + return self.sock:send(...) + end + + self.sock_receive = function(self,...) + return self.sock:receive(...) + end + + self.sock_close = function(self) + --self.sock:shutdown() Causes errors? + self.sock:close() + end + + self = sync.extend(self) + return self +end + +return new diff --git a/websocket/ev_common.lua b/websocket/ev_common.lua new file mode 100644 index 0000000..58a2f0e --- /dev/null +++ b/websocket/ev_common.lua @@ -0,0 +1,161 @@ +local ev = require'ev' +local frame = require'websocket.frame' +local tinsert = table.insert +local tconcat = table.concat +local eps = 2^-40 + +local detach = function(f,loop) + if ev.Idle then + ev.Idle.new(function(loop,io) + io:stop(loop) + f() + end):start(loop) + else + ev.Timer.new(function(loop,io) + io:stop(loop) + f() + end,eps):start(loop) + end +end + +local async_send = function(sock,loop) + assert(sock) + loop = loop or ev.Loop.default + local sock_send = sock.send + local buffer + local index + local callbacks = {} + local send = function(loop,write_io) + local len = #buffer + local sent,err,last = sock_send(sock,buffer,index) + if not sent and err ~= 'timeout' then + write_io:stop(loop) + if callbacks.on_err then + if write_io:is_active() then + callbacks.on_err(err) + else + detach(function() + callbacks.on_err(err) + end,loop) + end + end + elseif sent then + local copy = buffer + buffer = nil + index = nil + write_io:stop(loop) + if callbacks.on_sent then + -- detach calling callbacks.on_sent from current + -- exection if thiis call context is not + -- the send io to let send_async(_,on_sent,_) truely + -- behave async. + if write_io:is_active() then + + callbacks.on_sent(copy) + else + -- on_sent is only defined when responding to "on message for close op" + -- so this can happen only once per lifetime of a websocket instance. + -- callbacks.on_sent may be overwritten by a new call to send_async + -- (e.g. due to calling ws:close(...) or ws:send(...)) + local on_sent = callbacks.on_sent + detach(function() + on_sent(copy) + end,loop) + end + end + else + assert(last < len) + index = last + 1 + end + end + local io = ev.IO.new(send,sock:getfd(),ev.WRITE) + local stop = function() + io:stop(loop) + buffer = nil + index = nil + end + local send_async = function(data,on_sent,on_err) + if buffer then + -- a write io is still running + buffer = buffer..data + return #buffer + else + buffer = data + end + callbacks.on_sent = on_sent + callbacks.on_err = on_err + if not io:is_active() then + send(loop,io) + if index ~= nil then + io:start(loop) + end + end + local buffered = (buffer and #buffer - (index or 0)) or 0 + return buffered + end + return send_async,stop +end + +local message_io = function(sock,loop,on_message,on_error) + assert(sock) + assert(loop) + assert(on_message) + assert(on_error) + local last + local frames = {} + local first_opcode + assert(sock:getfd() > -1) + local message_io + local dispatch = function(loop,io) + -- could be stopped meanwhile by on_message function + while message_io:is_active() do + local encoded,err,part = sock:receive(100000) + if err then + if err == 'timeout' and #part == 0 then + return + elseif #part == 0 then + if message_io then + message_io:stop(loop) + end + on_error(err,io,sock) + return + end + end + if last then + encoded = last..(encoded or part) + last = nil + else + encoded = encoded or part + end + + repeat + local decoded,fin,opcode,rest = frame.decode(encoded) + if decoded then + if not first_opcode then + first_opcode = opcode + end + tinsert(frames,decoded) + encoded = rest + if fin == true then + on_message(tconcat(frames),first_opcode) + frames = {} + first_opcode = nil + end + end + until not decoded + if #encoded > 0 then + last = encoded + end + end + end + message_io = ev.IO.new(dispatch,sock:getfd(),ev.READ) + message_io:start(loop) + -- the might be already data waiting (which will not trigger the IO) + dispatch(loop,message_io) + return message_io +end + +return { + async_send = async_send, + message_io = message_io +} diff --git a/websocket/frame.lua b/websocket/frame.lua new file mode 100644 index 0000000..9888089 --- /dev/null +++ b/websocket/frame.lua @@ -0,0 +1,214 @@ +-- Following Websocket RFC: http://tools.ietf.org/html/rfc6455 +local bit = require'websocket.bit' +local band = bit.band +local bxor = bit.bxor +local bor = bit.bor +local tremove = table.remove +local srep = string.rep +local ssub = string.sub +local sbyte = string.byte +local schar = string.char +local band = bit.band +local rshift = bit.rshift +local tinsert = table.insert +local tconcat = table.concat +local mmin = math.min +local mfloor = math.floor +local mrandom = math.random +local unpack = unpack or table.unpack +local tools = require'websocket.tools' +local write_int8 = tools.write_int8 +local write_int16 = tools.write_int16 +local write_int32 = tools.write_int32 +local read_int8 = tools.read_int8 +local read_int16 = tools.read_int16 +local read_int32 = tools.read_int32 + +local bits = function(...) + local n = 0 + for _,bitn in pairs{...} do + n = n + 2^bitn + end + return n +end + +local bit_7 = bits(7) +local bit_0_3 = bits(0,1,2,3) +local bit_0_6 = bits(0,1,2,3,4,5,6) + +-- TODO: improve performance +local xor_mask = function(encoded,mask,payload) + local transformed,transformed_arr = {},{} + -- xor chunk-wise to prevent stack overflow. + -- sbyte and schar multiple in/out values + -- which require stack + for p=1,payload,2000 do + local last = mmin(p+1999,payload) + local original = {sbyte(encoded,p,last)} + for i=1,#original do + local j = (i-1) % 4 + 1 + transformed[i] = bxor(original[i],mask[j]) + end + local xored = schar(unpack(transformed,1,#original)) + tinsert(transformed_arr,xored) + end + return tconcat(transformed_arr) +end + +local encode_header_small = function(header, payload) + return schar(header, payload) +end + +local encode_header_medium = function(header, payload, len) + return schar(header, payload, band(rshift(len, 8), 0xFF), band(len, 0xFF)) +end + +local encode_header_big = function(header, payload, high, low) + return schar(header, payload)..write_int32(high)..write_int32(low) +end + +local encode = function(data,opcode,masked,fin) + local header = opcode or 1-- TEXT is default opcode + if fin == nil or fin == true then + header = bor(header,bit_7) + end + local payload = 0 + if masked then + payload = bor(payload,bit_7) + end + local len = #data + local chunks = {} + if len < 126 then + payload = bor(payload,len) + tinsert(chunks,encode_header_small(header,payload)) + elseif len <= 0xffff then + payload = bor(payload,126) + tinsert(chunks,encode_header_medium(header,payload,len)) + elseif len < 2^53 then + local high = mfloor(len/2^32) + local low = len - high*2^32 + payload = bor(payload,127) + tinsert(chunks,encode_header_big(header,payload,high,low)) + end + if not masked then + tinsert(chunks,data) + else + local m1 = mrandom(0,0xff) + local m2 = mrandom(0,0xff) + local m3 = mrandom(0,0xff) + local m4 = mrandom(0,0xff) + local mask = {m1,m2,m3,m4} + tinsert(chunks,write_int8(m1,m2,m3,m4)) + tinsert(chunks,xor_mask(data,mask,#data)) + end + return tconcat(chunks) +end + +local decode = function(encoded) + local encoded_bak = encoded + if #encoded < 2 then + return nil,2-#encoded + end + local pos,header,payload + pos,header = read_int8(encoded,1) + pos,payload = read_int8(encoded,pos) + local high,low + encoded = ssub(encoded,pos) + local bytes = 2 + local fin = band(header,bit_7) > 0 + local opcode = band(header,bit_0_3) + local mask = band(payload,bit_7) > 0 + payload = band(payload,bit_0_6) + if payload > 125 then + if payload == 126 then + if #encoded < 2 then + return nil,2-#encoded + end + pos,payload = read_int16(encoded,1) + elseif payload == 127 then + if #encoded < 8 then + return nil,8-#encoded + end + pos,high = read_int32(encoded,1) + pos,low = read_int32(encoded,pos) + payload = high*2^32 + low + if payload < 0xffff or payload > 2^53 then + assert(false,'INVALID PAYLOAD '..payload) + end + else + assert(false,'INVALID PAYLOAD '..payload) + end + encoded = ssub(encoded,pos) + bytes = bytes + pos - 1 + end + local decoded + if mask then + local bytes_short = payload + 4 - #encoded + if bytes_short > 0 then + return nil,bytes_short + end + local m1,m2,m3,m4 + pos,m1 = read_int8(encoded,1) + pos,m2 = read_int8(encoded,pos) + pos,m3 = read_int8(encoded,pos) + pos,m4 = read_int8(encoded,pos) + encoded = ssub(encoded,pos) + local mask = { + m1,m2,m3,m4 + } + decoded = xor_mask(encoded,mask,payload) + bytes = bytes + 4 + payload + else + local bytes_short = payload - #encoded + if bytes_short > 0 then + return nil,bytes_short + end + if #encoded > payload then + decoded = ssub(encoded,1,payload) + else + decoded = encoded + end + bytes = bytes + payload + end + return decoded,fin,opcode,encoded_bak:sub(bytes+1),mask +end + +local encode_close = function(code,reason) + if code then + local data = write_int16(code) + if reason then + data = data..tostring(reason) + end + return data + end + return '' +end + +local decode_close = function(data) + local _,code,reason + if data then + if #data > 1 then + _,code = read_int16(data,1) + end + if #data > 2 then + reason = data:sub(3) + end + end + return code,reason +end + +return { + encode = encode, + decode = decode, + encode_close = encode_close, + decode_close = decode_close, + encode_header_small = encode_header_small, + encode_header_medium = encode_header_medium, + encode_header_big = encode_header_big, + CONTINUATION = 0, + TEXT = 1, + BINARY = 2, + CLOSE = 8, + PING = 9, + PONG = 10 +} diff --git a/websocket/handshake.lua b/websocket/handshake.lua new file mode 100644 index 0000000..b90d267 --- /dev/null +++ b/websocket/handshake.lua @@ -0,0 +1,104 @@ +local sha1 = require'websocket.tools'.sha1 +local base64 = require'websocket.tools'.base64 +local tinsert = table.insert + +local guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +local sec_websocket_accept = function(sec_websocket_key) + local a = sec_websocket_key..guid + local sha1 = sha1(a) + assert((#sha1 % 2) == 0) + return base64.encode(sha1) +end + +local http_headers = function(request) + local headers = {} + if not request:match('.*HTTP/1%.1') then + return headers + end + request = request:match('[^\r\n]+\r\n(.*)') + local empty_line + for line in request:gmatch('[^\r\n]*\r\n') do + local name,val = line:match('([^%s]+)%s*:%s*([^\r\n]+)') + if name and val then + name = name:lower() + if not name:match('sec%-websocket') then + val = val:lower() + end + if not headers[name] then + headers[name] = val + else + headers[name] = headers[name]..','..val + end + elseif line == '\r\n' then + empty_line = true + else + assert(false,line..'('..#line..')') + end + end + return headers,request:match('\r\n\r\n(.*)') +end + +local upgrade_request = function(req) + local format = string.format + local lines = { + format('GET %s HTTP/1.1',req.uri or ''), + format('Host: %s',req.host), + 'Upgrade: websocket', + 'Connection: Upgrade', + format('Sec-WebSocket-Key: %s',req.key), + format('Sec-WebSocket-Protocol: %s',table.concat(req.protocols,', ')), + 'Sec-WebSocket-Version: 13', + } + if req.origin then + tinsert(lines,string.format('Origin: %s',req.origin)) + end + if req.port and req.port ~= 80 then + lines[2] = format('Host: %s:%d',req.host,req.port) + end + tinsert(lines,'\r\n') + return table.concat(lines,'\r\n') +end + +local accept_upgrade = function(request,protocols) + local headers = http_headers(request) + if headers['upgrade'] ~= 'websocket' or + not headers['connection'] or + not headers['connection']:match('upgrade') or + headers['sec-websocket-key'] == nil or + headers['sec-websocket-version'] ~= '13' then + return nil,'HTTP/1.1 400 Bad Request\r\n\r\n' + end + local prot + if headers['sec-websocket-protocol'] then + for protocol in headers['sec-websocket-protocol']:gmatch('([^,%s]+)%s?,?') do + for _,supported in ipairs(protocols) do + if supported == protocol then + prot = protocol + break + end + end + if prot then + break + end + end + end + local lines = { + 'HTTP/1.1 101 Switching Protocols', + 'Upgrade: websocket', + 'Connection: '..headers['connection'], + string.format('Sec-WebSocket-Accept: %s',sec_websocket_accept(headers['sec-websocket-key'])), + } + if prot then + tinsert(lines,string.format('Sec-WebSocket-Protocol: %s',prot)) + end + tinsert(lines,'\r\n') + return table.concat(lines,'\r\n'),prot +end + +return { + sec_websocket_accept = sec_websocket_accept, + http_headers = http_headers, + accept_upgrade = accept_upgrade, + upgrade_request = upgrade_request, +} diff --git a/websocket/server.lua b/websocket/server.lua new file mode 100644 index 0000000..4b5e38f --- /dev/null +++ b/websocket/server.lua @@ -0,0 +1,5 @@ +return setmetatable({},{__index = function(self, name) + local backend = require("websocket.server_" .. name) + self[name] = backend + return backend +end}) diff --git a/websocket/server_copas.lua b/websocket/server_copas.lua new file mode 100644 index 0000000..ea13006 --- /dev/null +++ b/websocket/server_copas.lua @@ -0,0 +1,146 @@ + +local socket = require'socket' +local copas = require'copas' +local tools = require'websocket.tools' +local frame = require'websocket.frame' +local handshake = require'websocket.handshake' +local sync = require'websocket.sync' +local tconcat = table.concat +local tinsert = table.insert + +local clients = {} + +local client = function(sock,protocol) + local copas = require'copas' + + local self = {} + + self.state = 'OPEN' + self.is_server = true + + self.sock_send = function(self,...) + return copas.send(sock,...) + end + + self.sock_receive = function(self,...) + return copas.receive(sock,...) + end + + self.sock_close = function(self) + sock:shutdown() + sock:close() + end + + self = sync.extend(self) + + self.on_close = function(self) + clients[protocol][self] = nil + end + + self.broadcast = function(self,...) + for client in pairs(clients[protocol]) do + if client ~= self then + client:send(...) + end + end + self:send(...) + end + + return self +end + +local listen = function(opts) + + local copas = require'copas' + assert(opts and (opts.protocols or opts.default)) + local on_error = opts.on_error or function(s) print(s) end + local listener,err = socket.bind(opts.interface or '*',opts.port or 80) + if err then + error(err) + end + local protocols = {} + if opts.protocols then + for protocol in pairs(opts.protocols) do + clients[protocol] = {} + tinsert(protocols,protocol) + end + end + -- true is the 'magic' index for the default handler + clients[true] = {} + copas.addserver( + listener, + function(sock) + local request = {} + repeat + -- no timeout used, so should either return with line or err + local line,err = copas.receive(sock,'*l') + if line then + request[#request+1] = line + else + sock:close() + if on_error then + on_error('invalid request') + end + return + end + until line == '' + local upgrade_request = tconcat(request,'\r\n') + local response,protocol = handshake.accept_upgrade(upgrade_request,protocols) + if not response then + copas.send(sock,protocol) + sock:close() + if on_error then + on_error('invalid request') + end + return + end + copas.send(sock,response) + local handler + local new_client + local protocol_index + if protocol and opts.protocols[protocol] then + protocol_index = protocol + handler = opts.protocols[protocol] + elseif opts.default then + -- true is the 'magic' index for the default handler + protocol_index = true + handler = opts.default + else + sock:close() + if on_error then + on_error('bad protocol') + end + return + end + new_client = client(sock,protocol_index) + clients[protocol_index][new_client] = true + handler(new_client) + -- this is a dirty trick for preventing + -- copas from automatically and prematurely closing + -- the socket + while new_client.state ~= 'CLOSED' do + local dummy = { + send = function() end, + close = function() end + } + copas.send(dummy) + end + end) + local self = {} + self.close = function(_,keep_clients) + copas.removeserver(listener) + listener = nil + if not keep_clients then + for protocol,clients in pairs(clients) do + for client in pairs(clients) do + client:close() + end + end + end + end + return self +end + +return { + listen = listen +} diff --git a/websocket/server_ev.lua b/websocket/server_ev.lua new file mode 100644 index 0000000..abd7a20 --- /dev/null +++ b/websocket/server_ev.lua @@ -0,0 +1,266 @@ + +local socket = require'socket' +local tools = require'websocket.tools' +local frame = require'websocket.frame' +local handshake = require'websocket.handshake' +local tconcat = table.concat +local tinsert = table.insert +local ev +local loop + +local clients = {} +clients[true] = {} + +local client = function(sock,protocol) + assert(sock) + sock:setoption('tcp-nodelay',true) + local fd = sock:getfd() + local message_io + local close_timer + local async_send = require'websocket.ev_common'.async_send(sock,loop) + local self = {} + self.state = 'OPEN' + self.sock = sock + local user_on_error + local on_error = function(s,err) + if clients[protocol] ~= nil and clients[protocol][self] ~= nil then + clients[protocol][self] = nil + end + if user_on_error then + user_on_error(self,err) + else + print('Websocket server error',err) + end + end + local user_on_close + local on_close = function(was_clean,code,reason) + if clients[protocol] ~= nil and clients[protocol][self] ~= nil then + clients[protocol][self] = nil + end + if close_timer then + close_timer:stop(loop) + close_timer = nil + end + message_io:stop(loop) + self.state = 'CLOSED' + if user_on_close then + user_on_close(self,was_clean,code,reason or '') + end + sock:shutdown() + sock:close() + end + + local handle_sock_err = function(err) + if err == 'closed' then + if self.state ~= 'CLOSED' then + on_close(false,1006,'') + end + else + on_error(err) + end + end + local user_on_message = function() end + local TEXT = frame.TEXT + local BINARY = frame.BINARY + local on_message = function(message,opcode) + if opcode == TEXT or opcode == BINARY then + user_on_message(self,message,opcode) + elseif opcode == frame.CLOSE then + if self.state ~= 'CLOSING' then + self.state = 'CLOSING' + local code,reason = frame.decode_close(message) + local encoded = frame.encode_close(code) + encoded = frame.encode(encoded,frame.CLOSE) + async_send(encoded, + function() + on_close(true,code or 1006,reason) + end,handle_sock_err) + else + on_close(true,1006,'') + end + end + end + + self.send = function(_,message,opcode) + local encoded = frame.encode(message,opcode or frame.TEXT) + return async_send(encoded) + end + + self.on_close = function(_,on_close_arg) + user_on_close = on_close_arg + end + + self.on_error = function(_,on_error_arg) + user_on_error = on_error_arg + end + + self.on_message = function(_,on_message_arg) + user_on_message = on_message_arg + end + + self.broadcast = function(_,...) + for client in pairs(clients[protocol]) do + if client.state == 'OPEN' then + client:send(...) + end + end + end + + self.close = function(_,code,reason,timeout) + if clients[protocol] ~= nil and clients[protocol][self] ~= nil then + clients[protocol][self] = nil + end + if not message_io then + self:start() + end + if self.state == 'OPEN' then + self.state = 'CLOSING' + assert(message_io) + timeout = timeout or 3 + local encoded = frame.encode_close(code or 1000,reason or '') + encoded = frame.encode(encoded,frame.CLOSE) + async_send(encoded) + close_timer = ev.Timer.new(function() + close_timer = nil + on_close(false,1006,'timeout') + end,timeout) + close_timer:start(loop) + end + end + + self.start = function() + message_io = require'websocket.ev_common'.message_io( + sock,loop, + on_message, + handle_sock_err) + end + + + return self +end + +local listen = function(opts) + assert(opts and (opts.protocols or opts.default)) + ev = require'ev' + loop = opts.loop or ev.Loop.default + local user_on_error + local on_error = function(s,err) + if user_on_error then + user_on_error(s,err) + else + print(err) + end + end + local protocols = {} + if opts.protocols then + for protocol in pairs(opts.protocols) do + clients[protocol] = {} + tinsert(protocols,protocol) + end + end + local self = {} + self.on_error = function(self,on_error) + user_on_error = on_error + end + local listener,err = socket.bind(opts.interface or '*',opts.port or 80) + if not listener then + error(err) + end + listener:settimeout(0) + + self.sock = function() + return listener + end + + local listen_io = ev.IO.new( + function() + local client_sock = listener:accept() + client_sock:settimeout(0) + assert(client_sock) + local request = {} + local last + ev.IO.new( + function(loop,read_io) + repeat + local line,err,part = client_sock:receive('*l') + if line then + if last then + line = last..line + last = nil + end + request[#request+1] = line + elseif err ~= 'timeout' then + on_error(self,'Websocket Handshake failed due to socket err:'..err) + read_io:stop(loop) + return + else + last = part + return + end + until line == '' + read_io:stop(loop) + local upgrade_request = tconcat(request,'\r\n') + local response,protocol = handshake.accept_upgrade(upgrade_request,protocols) + if not response then + print('Handshake failed, Request:') + print(upgrade_request) + client_sock:close() + return + end + local index + ev.IO.new( + function(loop,write_io) + local len = #response + local sent,err = client_sock:send(response,index) + if not sent then + write_io:stop(loop) + print('Websocket client closed while handshake',err) + elseif sent == len then + write_io:stop(loop) + local protocol_handler + local new_client + local protocol_index + if protocol and opts.protocols[protocol] then + protocol_index = protocol + protocol_handler = opts.protocols[protocol] + elseif opts.default then + -- true is the 'magic' index for the default handler + protocol_index = true + protocol_handler = opts.default + else + client_sock:close() + if on_error then + on_error('bad protocol') + end + return + end + new_client = client(client_sock,protocol_index) + clients[protocol_index][new_client] = true + protocol_handler(new_client) + new_client:start(loop) + else + assert(sent < len) + index = sent + end + end,client_sock:getfd(),ev.WRITE):start(loop) + end,client_sock:getfd(),ev.READ):start(loop) + end,listener:getfd(),ev.READ) + self.close = function(keep_clients) + listen_io:stop(loop) + listener:close() + listener = nil + if not keep_clients then + for protocol,clients in pairs(clients) do + for client in pairs(clients) do + client:close() + end + end + end + end + listen_io:start(loop) + return self +end + +return { + listen = listen +} diff --git a/websocket/sync.lua b/websocket/sync.lua new file mode 100644 index 0000000..2b94312 --- /dev/null +++ b/websocket/sync.lua @@ -0,0 +1,203 @@ +local frame = require'websocket.frame' +local handshake = require'websocket.handshake' +local tools = require'websocket.tools' +local ssl = require'ssl' +local tinsert = table.insert +local tconcat = table.concat + +local receive = function(self) + if self.state ~= 'OPEN' and not self.is_closing then + return nil,nil,false,1006,'wrong state' + end + local first_opcode + local frames + local bytes = 3 + local encoded = '' + local clean = function(was_clean,code,reason) + self.state = 'CLOSED' + self:sock_close() + if self.on_close then + self:on_close() + end + return nil,nil,was_clean,code,reason or 'closed' + end + while true do + local chunk,err = self:sock_receive(bytes) + if err then + return clean(false,1006,err) + end + encoded = encoded..chunk + local decoded,fin,opcode,_,masked = frame.decode(encoded) + if not self.is_server and masked then + return clean(false,1006,'Websocket receive failed: frame was not masked') + end + if decoded then + if opcode == frame.CLOSE then + if not self.is_closing then + local code,reason = frame.decode_close(decoded) + -- echo code + local msg = frame.encode_close(code) + local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) + local n,err = self:sock_send(encoded) + if n == #encoded then + return clean(true,code,reason) + else + return clean(false,code,err) + end + else + return decoded,opcode + end + end + if not first_opcode then + first_opcode = opcode + end + if not fin then + if not frames then + frames = {} + elseif opcode ~= frame.CONTINUATION then + return clean(false,1002,'protocol error') + end + bytes = 3 + encoded = '' + tinsert(frames,decoded) + elseif not frames then + return decoded,first_opcode + else + tinsert(frames,decoded) + return tconcat(frames),first_opcode + end + else + assert(type(fin) == 'number' and fin > 0) + bytes = fin + end + end + assert(false,'never reach here') +end + +local send = function(self,data,opcode) + if self.state ~= 'OPEN' then + return nil,false,1006,'wrong state' + end + local encoded = frame.encode(data,opcode or frame.TEXT,not self.is_server) + local n,err = self:sock_send(encoded) + if n ~= #encoded then + return nil,self:close(1006,err) + end + return true +end + +local close = function(self,code,reason) + if self.state ~= 'OPEN' then + return false,1006,'wrong state' + end + if self.state == 'CLOSED' then + return false,1006,'wrong state' + end + local msg = frame.encode_close(code or 1000,reason) + local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) + local n,err = self:sock_send(encoded) + local was_clean = false + local code = 1005 + local reason = '' + if n == #encoded then + self.is_closing = true + local rmsg,opcode = self:receive() + if rmsg and opcode == frame.CLOSE then + code,reason = frame.decode_close(rmsg) + was_clean = true + end + else + reason = err + end + self:sock_close() + if self.on_close then + self:on_close() + end + self.state = 'CLOSED' + return was_clean,code,reason or '' +end + +local connect = function(self,ws_url,ws_protocol,ssl_params) + if self.state ~= 'CLOSED' then + return nil,'wrong state',nil + end + local protocol,host,port,uri = tools.parse_url(ws_url) + -- Preconnect (for SSL if needed) + local _,err = self:sock_connect(host,port) + if err then + return nil,err,nil + end + if protocol == 'wss' then + self.sock = ssl.wrap(self.sock, ssl_params) + self.sock:dohandshake() + elseif protocol ~= "ws" then + return nil, 'bad protocol' + end + local ws_protocols_tbl = {''} + if type(ws_protocol) == 'string' then + ws_protocols_tbl = {ws_protocol} + elseif type(ws_protocol) == 'table' then + ws_protocols_tbl = ws_protocol + end + local key = tools.generate_key() + local req = handshake.upgrade_request + { + key = key, + host = host, + port = port, + protocols = ws_protocols_tbl, + uri = uri + } + local n,err = self:sock_send(req) + if n ~= #req then + return nil,err,nil + end + local resp = {} + repeat + local line,err = self:sock_receive('*l') + resp[#resp+1] = line + if err then + return nil,err,nil + end + until line == '' + local response = table.concat(resp,'\r\n') + local headers = handshake.http_headers(response) + local expected_accept = handshake.sec_websocket_accept(key) + if headers['sec-websocket-accept'] ~= expected_accept then + local msg = 'Websocket Handshake failed: Invalid Sec-Websocket-Accept (expected %s got %s)' + return nil,msg:format(expected_accept,headers['sec-websocket-accept'] or 'nil'),headers + end + self.state = 'OPEN' + return true,headers['sec-websocket-protocol'],headers +end + +local extend = function(obj) + assert(obj.sock_send) + assert(obj.sock_receive) + assert(obj.sock_close) + + assert(obj.is_closing == nil) + assert(obj.receive == nil) + assert(obj.send == nil) + assert(obj.close == nil) + assert(obj.connect == nil) + + if not obj.is_server then + assert(obj.sock_connect) + end + + if not obj.state then + obj.state = 'CLOSED' + end + + obj.receive = receive + obj.send = send + obj.close = close + obj.connect = connect + + return obj +end + +return { + extend = extend +} diff --git a/websocket/tools.lua b/websocket/tools.lua new file mode 100644 index 0000000..59586e3 --- /dev/null +++ b/websocket/tools.lua @@ -0,0 +1,203 @@ +local bit = require'websocket.bit' +local mime = require'mime' +local rol = bit.rol +local bxor = bit.bxor +local bor = bit.bor +local band = bit.band +local bnot = bit.bnot +local lshift = bit.lshift +local rshift = bit.rshift +local sunpack = string.unpack +local srep = string.rep +local schar = string.char +local tremove = table.remove +local tinsert = table.insert +local tconcat = table.concat +local mrandom = math.random + +local read_n_bytes = function(str, pos, n) + pos = pos or 1 + return pos+n, string.byte(str, pos, pos + n - 1) +end + +local read_int8 = function(str, pos) + return read_n_bytes(str, pos, 1) +end + +local read_int16 = function(str, pos) + local new_pos,a,b = read_n_bytes(str, pos, 2) + return new_pos, lshift(a, 8) + b +end + +local read_int32 = function(str, pos) + local new_pos,a,b,c,d = read_n_bytes(str, pos, 4) + return new_pos, + lshift(a, 24) + + lshift(b, 16) + + lshift(c, 8 ) + + d +end + +local pack_bytes = string.char + +local write_int8 = pack_bytes + +local write_int16 = function(v) + return pack_bytes(rshift(v, 8), band(v, 0xFF)) +end + +local write_int32 = function(v) + return pack_bytes( + band(rshift(v, 24), 0xFF), + band(rshift(v, 16), 0xFF), + band(rshift(v, 8), 0xFF), + band(v, 0xFF) + ) +end + +-- used for generate key random ops +math.randomseed(os.time()) + +-- SHA1 hashing from luacrypto, if available +local sha1_crypto +local done,crypto = pcall(require,'crypto') +if done then + sha1_crypto = function(msg) + return crypto.digest('sha1',msg,true) + end +end + +-- from wiki article, not particularly clever impl +local sha1_wiki = function(msg) + local h0 = 0x67452301 + local h1 = 0xEFCDAB89 + local h2 = 0x98BADCFE + local h3 = 0x10325476 + local h4 = 0xC3D2E1F0 + + local bits = #msg * 8 + -- append b10000000 + msg = msg..schar(0x80) + + -- 64 bit length will be appended + local bytes = #msg + 8 + + -- 512 bit append stuff + local fill_bytes = 64 - (bytes % 64) + if fill_bytes ~= 64 then + msg = msg..srep(schar(0),fill_bytes) + end + + -- append 64 big endian length + local high = math.floor(bits/2^32) + local low = bits - high*2^32 + msg = msg..write_int32(high)..write_int32(low) + + assert(#msg % 64 == 0,#msg % 64) + + for j=1,#msg,64 do + local chunk = msg:sub(j,j+63) + assert(#chunk==64,#chunk) + local words = {} + local next = 1 + local word + repeat + next,word = read_int32(chunk, next) + tinsert(words, word) + until next > 64 + assert(#words==16) + for i=17,80 do + words[i] = bxor(words[i-3],words[i-8],words[i-14],words[i-16]) + words[i] = rol(words[i],1) + end + local a = h0 + local b = h1 + local c = h2 + local d = h3 + local e = h4 + + for i=1,80 do + local k,f + if i > 0 and i < 21 then + f = bor(band(b,c),band(bnot(b),d)) + k = 0x5A827999 + elseif i > 20 and i < 41 then + f = bxor(b,c,d) + k = 0x6ED9EBA1 + elseif i > 40 and i < 61 then + f = bor(band(b,c),band(b,d),band(c,d)) + k = 0x8F1BBCDC + elseif i > 60 and i < 81 then + f = bxor(b,c,d) + k = 0xCA62C1D6 + end + + local temp = rol(a,5) + f + e + k + words[i] + e = d + d = c + c = rol(b,30) + b = a + a = temp + end + + h0 = h0 + a + h1 = h1 + b + h2 = h2 + c + h3 = h3 + d + h4 = h4 + e + + end + + -- necessary on sizeof(int) == 32 machines + h0 = band(h0,0xffffffff) + h1 = band(h1,0xffffffff) + h2 = band(h2,0xffffffff) + h3 = band(h3,0xffffffff) + h4 = band(h4,0xffffffff) + + return write_int32(h0)..write_int32(h1)..write_int32(h2)..write_int32(h3)..write_int32(h4) +end + +local base64_encode = function(data) + return (mime.b64(data)) +end + +local DEFAULT_PORTS = {ws = 80, wss = 443} + +local parse_url = function(url) + local protocol, address, uri = url:match('^(%w+)://([^/]+)(.*)$') + if not protocol then error('Invalid URL:'..url) end + protocol = protocol:lower() + local host, port = address:match("^(.+):(%d+)$") + if not host then + host = address + port = DEFAULT_PORTS[protocol] + end + if not uri or uri == '' then uri = '/' end + return protocol, host, tonumber(port), uri +end + +local generate_key = function() + local r1 = mrandom(0,0xfffffff) + local r2 = mrandom(0,0xfffffff) + local r3 = mrandom(0,0xfffffff) + local r4 = mrandom(0,0xfffffff) + local key = write_int32(r1)..write_int32(r2)..write_int32(r3)..write_int32(r4) + assert(#key==16,#key) + return base64_encode(key) +end + +return { + sha1 = sha1_crypto or sha1_wiki, + base64 = { + encode = base64_encode + }, + parse_url = parse_url, + generate_key = generate_key, + read_int8 = read_int8, + read_int16 = read_int16, + read_int32 = read_int32, + write_int8 = write_int8, + write_int16 = write_int16, + write_int32 = write_int32, +}