You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
5.4 KiB
214 lines
5.4 KiB
-- Following Websocket RFC: http://tools.ietf.org/html/rfc6455
|
|
local bit = require'websocket.bit'
|
|
local band = bit.band
|
|
local bxor = bit.bxor
|
|
local bor = bit.bor
|
|
local tremove = table.remove
|
|
local srep = string.rep
|
|
local ssub = string.sub
|
|
local sbyte = string.byte
|
|
local schar = string.char
|
|
local band = bit.band
|
|
local rshift = bit.rshift
|
|
local tinsert = table.insert
|
|
local tconcat = table.concat
|
|
local mmin = math.min
|
|
local mfloor = math.floor
|
|
local mrandom = math.random
|
|
local unpack = unpack or table.unpack
|
|
local tools = require'websocket.tools'
|
|
local write_int8 = tools.write_int8
|
|
local write_int16 = tools.write_int16
|
|
local write_int32 = tools.write_int32
|
|
local read_int8 = tools.read_int8
|
|
local read_int16 = tools.read_int16
|
|
local read_int32 = tools.read_int32
|
|
|
|
--- Return the sum of 2^n over every non-nil bit position n passed in,
-- i.e. a number with those bits set. A repeated position is simply added
-- again, so callers pass distinct positions.
local bits = function(...)
  local total = 0
  local count = select('#', ...)
  for idx = 1, count do
    local position = select(idx, ...)
    if position ~= nil then
      total = total + 2 ^ position
    end
  end
  return total
end

-- 0x80: FIN flag of the first header byte / MASK flag of the second byte.
local bit_7 = bits(7)
-- 0x0F: opcode field of the first header byte.
local bit_0_3 = bits(0, 1, 2, 3)
-- 0x7F: 7-bit payload length field of the second header byte.
local bit_0_6 = bits(0, 1, 2, 3, 4, 5, 6)
|
|
|
|
--- XOR the first `payload` bytes of `encoded` with the 4-byte key `mask`
-- (byte i uses mask[(i-1) % 4 + 1], as in RFC 6455 section 5.3). XOR is its
-- own inverse, so the same function masks (encode) and unmasks (decode).
-- Bytes are processed in 2000-byte slices because sbyte/schar/unpack pass
-- every byte through the Lua stack; 2000 is a multiple of 4, so the key
-- index can be derived from the absolute byte offset.
-- TODO: improve performance
local xor_mask = function(encoded, mask, payload)
  local slices = {}
  local scratch = {}
  local offset = 1
  while offset <= payload do
    local stop = mmin(offset + 1999, payload)
    local raw = { sbyte(encoded, offset, stop) }
    local count = #raw
    for k = 1, count do
      scratch[k] = bxor(raw[k], mask[(offset + k - 2) % 4 + 1])
    end
    slices[#slices + 1] = schar(unpack(scratch, 1, count))
    offset = offset + 2000
  end
  return tconcat(slices)
end
|
|
|
|
--- Pack the two fixed header bytes into a string. `encode` uses this when
-- the payload is shorter than 126 bytes, so `payload` already holds the
-- MASK bit together with the 7-bit length.
local function encode_header_small(header, payload)
  local prefix = schar(header, payload)
  return prefix
end
|
|
|
|
--- Pack the two fixed header bytes followed by `len` as a big-endian
-- 16-bit value (high byte first). `encode` uses this for payloads of
-- 126..0xffff bytes, with length marker 126 already in `payload`.
local function encode_header_medium(header, payload, len)
  local len_hi = band(rshift(len, 8), 0xFF)
  local len_lo = band(len, 0xFF)
  return schar(header, payload, len_hi, len_lo)
end
|
|
|
|
--- Pack the two fixed header bytes followed by a 64-bit length supplied as
-- two 32-bit halves (`high` first, then `low`), each written with
-- write_int32. `encode` uses this for payloads above 0xffff bytes.
local function encode_header_big(header, payload, high, low)
  local lead = schar(header, payload)
  local extended = write_int32(high) .. write_int32(low)
  return lead .. extended
end
|
|
|
|
--- Encode `data` as one WebSocket frame (header layout per RFC 6455 5.2).
-- @param data payload string
-- @param opcode frame opcode; defaults to 1 (TEXT)
-- @param masked when truthy, sets the MASK bit, generates a 4-byte masking
--   key and XOR-masks the payload with it
-- @param fin nil or true sets the FIN bit; any other value leaves it clear
-- @return the complete frame as a string
-- Raises 'INVALID PAYLOAD <len>' when #data is 2^53 or more. The 64-bit
-- length path splits the length with floating-point arithmetic, which is
-- only guaranteed exact below 2^53 on double-based Lua versions. Previously
-- such input fell through every branch and yielded a frame with no header.
local encode = function(data,opcode,masked,fin)
  local header = opcode or 1 -- TEXT is default opcode
  if fin == nil or fin == true then
    header = bor(header, bit_7)
  end

  local payload = 0
  if masked then
    payload = bor(payload, bit_7)
  end

  local len = #data
  local chunks = {}
  if len < 126 then
    -- Length fits directly into the 7-bit field of the second byte.
    payload = bor(payload, len)
    chunks[#chunks + 1] = encode_header_small(header, payload)
  elseif len <= 0xffff then
    -- Marker 126: an extended 16-bit length follows.
    payload = bor(payload, 126)
    chunks[#chunks + 1] = encode_header_medium(header, payload, len)
  elseif len < 2^53 then
    -- Marker 127: an extended 64-bit length follows, as two 32-bit halves.
    local high = mfloor(len / 2^32)
    local low = len - high * 2^32
    payload = bor(payload, 127)
    chunks[#chunks + 1] = encode_header_big(header, payload, high, low)
  else
    error('INVALID PAYLOAD '..len, 2)
  end

  if masked then
    -- NOTE(review): math.random is not an unpredictable source, and
    -- RFC 6455 section 10.3 expects unpredictable masking keys. Consider a
    -- stronger source if one is available to the project.
    local m1 = mrandom(0, 0xff)
    local m2 = mrandom(0, 0xff)
    local m3 = mrandom(0, 0xff)
    local m4 = mrandom(0, 0xff)
    chunks[#chunks + 1] = write_int8(m1, m2, m3, m4)
    chunks[#chunks + 1] = xor_mask(data, {m1, m2, m3, m4}, len)
  else
    chunks[#chunks + 1] = data
  end

  return tconcat(chunks)
end
|
|
|
|
--- Decode one WebSocket frame from the start of `encoded`.
-- Byte 1 carries the FIN flag (0x80) and opcode (0x0F). Byte 2 carries the
-- MASK flag (0x80) and a 7-bit length, where 126 means a 16-bit length
-- follows and 127 means a 64-bit length follows. If MASK is set, a 4-byte
-- key precedes the payload.
-- @param encoded bytes received so far; may hold less than one frame, or
--   one frame followed by further data
-- @return when a whole frame is present: the unmasked payload, the FIN flag
--   (boolean), the opcode, the bytes left after this frame, and whether the
--   frame was masked (boolean)
-- @return when input is incomplete: nil and the number of bytes still
--   missing for the current stage (a lower bound while the header is short)
-- A 64-bit length below 0xffff or above 2^53 fails assert with
-- 'INVALID PAYLOAD <n>'.
-- NOTE(review): that assert fires on peer-supplied data; presumably callers
-- guard decode with pcall, so confirm.
local decode = function(encoded)
  local frame = encoded

  -- The two fixed header bytes are not here yet.
  if #frame < 2 then
    return nil, 2 - #frame
  end

  local pos, byte1, byte2
  pos, byte1 = read_int8(frame, 1)
  pos, byte2 = read_int8(frame, pos)
  local rest = ssub(frame, pos)
  local consumed = 2

  local fin = band(byte1, bit_7) > 0
  local opcode = band(byte1, bit_0_3)
  local masked = band(byte2, bit_7) > 0
  local size = band(byte2, bit_0_6)

  -- 126 / 127 announce an extended 16-bit / 64-bit length field.
  if size > 125 then
    if size == 126 then
      if #rest < 2 then
        return nil, 2 - #rest
      end
      pos, size = read_int16(rest, 1)
    elseif size == 127 then
      if #rest < 8 then
        return nil, 8 - #rest
      end
      local high, low
      pos, high = read_int32(rest, 1)
      pos, low = read_int32(rest, pos)
      size = high * 2^32 + low
      if size < 0xffff or size > 2^53 then
        assert(false, 'INVALID PAYLOAD '..size)
      end
    else
      assert(false, 'INVALID PAYLOAD '..size)
    end
    rest = ssub(rest, pos)
    consumed = consumed + pos - 1
  end

  local decoded
  if masked then
    local missing = size + 4 - #rest
    if missing > 0 then
      return nil, missing
    end
    -- Read the 4-byte masking key, then unmask the payload after it.
    local key = {}
    pos = 1
    for k = 1, 4 do
      pos, key[k] = read_int8(rest, pos)
    end
    rest = ssub(rest, pos)
    decoded = xor_mask(rest, key, size)
    consumed = consumed + 4 + size
  else
    local missing = size - #rest
    if missing > 0 then
      return nil, missing
    end
    if #rest > size then
      decoded = ssub(rest, 1, size)
    else
      decoded = rest
    end
    consumed = consumed + size
  end

  return decoded, fin, opcode, frame:sub(consumed + 1), masked
end
|
|
|
|
--- Build a close-frame payload: the status `code` written with write_int16,
-- followed by tostring(reason) when a reason is given.
-- Without a code the payload is empty and any reason is ignored.
local encode_close = function(code, reason)
  if not code then
    return ''
  end
  local body = write_int16(code)
  if reason then
    body = body .. tostring(reason)
  end
  return body
end
|
|
|
|
--- Split a close-frame payload into its status code and reason text.
-- @param data close payload string, or nil
-- @return the code read with read_int16 from the first two bytes, or nil
--   when fewer than two bytes are present
-- @return everything after the first two bytes, or nil if nothing follows
local decode_close = function(data)
  local code, reason
  if data and #data >= 2 then
    code = (select(2, read_int16(data, 1)))
    if #data >= 3 then
      reason = data:sub(3)
    end
  end
  return code, reason
end
|
|
|
|
-- Module interface: the frame codec, close-payload helpers, the header
-- builders used by encode, and the opcode values from RFC 6455 section 5.2.
return {
  -- opcodes
  CONTINUATION = 0,
  TEXT = 1,
  BINARY = 2,
  CLOSE = 8,
  PING = 9,
  PONG = 10,

  -- whole-frame codec
  encode = encode,
  decode = decode,

  -- close-frame payload helpers
  encode_close = encode_close,
  decode_close = decode_close,

  -- header builders
  encode_header_small = encode_header_small,
  encode_header_medium = encode_header_medium,
  encode_header_big = encode_header_big,
}
|
|
|