diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4dd6f16..ea6f5c9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,6 +37,17 @@ jobs: version: v2.1.0 args: false # Will be run as part of `make check` + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install protobuf compiler + run: sudo apt-get install -y protobuf-compiler + + - name: Setup schema generator + run: make setup-schema-generator + - name: Check run: make check diff --git a/Makefile b/Makefile index 1d845f7..ae0735d 100644 --- a/Makefile +++ b/Makefile @@ -124,6 +124,40 @@ gen-schema: fi @.venv/bin/python3 tools/gen_lua_proto_schema $(OUTPUT) $(PROTO) +# Generate the base types schema from empty.proto +.PHONY: gen-types +gen-types: + @if [ ! -f .venv/bin/python3 ]; then \ + echo "Python virtual environment not found. Run 'make setup-schema-generator' first."; \ + exit 1; \ + fi + @echo "Generating src/protobuf/types.lua from empty.proto..." + @.venv/bin/python3 tools/gen_lua_proto_schema src/protobuf/types.lua empty.proto + @echo "Generated src/protobuf/types.lua" + +# Check that types.lua matches what would be generated (for CI) +.PHONY: check-types +check-types: + @if [ ! -f .venv/bin/python3 ]; then \ + echo "Python virtual environment not found. Run 'make setup-schema-generator' first."; \ + exit 1; \ + fi + @echo "Checking src/protobuf/types.lua is up to date..." + @mkdir -p build + @.venv/bin/python3 tools/gen_lua_proto_schema build/types.lua.tmp empty.proto + @if diff -q src/protobuf/types.lua build/types.lua.tmp >/dev/null 2>&1; then \ + echo "src/protobuf/types.lua is up to date"; \ + rm -f build/types.lua.tmp; \ + else \ + echo "ERROR: src/protobuf/types.lua is out of date!"; \ + echo "Run 'make gen-types' to regenerate it."; \ + echo ""; \ + echo "Diff:"; \ + diff src/protobuf/types.lua build/types.lua.tmp || true; \ + rm -f build/types.lua.tmp; \ + exit 1; \ + fi + # Format Lua code with stylua .PHONY: format format: @@ -162,7 +196,7 @@ lint: fi .PHONY: check -check: format-check lint +check: format-check lint check-types @echo "Code quality checks complete." # Clean generated files @@ -186,6 +220,8 @@ help: @echo "Schema Generation:" @echo " make setup-schema-generator - Setup Python venv for schema generator" @echo " make gen-schema PROTO= OUTPUT= - Generate Lua schema from proto file(s)" + @echo " make gen-types - Regenerate src/protobuf/types.lua" + @echo " make check-types - Verify types.lua matches empty.proto" @echo "" @echo "Code Quality:" @echo " make format - Format all code (Lua)" diff --git a/empty.proto b/empty.proto new file mode 100644 index 0000000..9022b92 --- /dev/null +++ b/empty.proto @@ -0,0 +1,2 @@ +syntax = "proto3"; +package empty; \ No newline at end of file diff --git a/src/protobuf/init.lua b/src/protobuf/init.lua index 4908ca5..0229f9f 100644 --- a/src/protobuf/init.lua +++ b/src/protobuf/init.lua @@ -8,7 +8,7 @@ local bit64 = bitn.bit64 --- Check if a value is a list (sequential table). --- @param t any The value to check. --- @return boolean is_list True if the value is a list. -local function IsList(t) +local function is_list(t) if type(t) ~= "table" then return false end @@ -38,7 +38,7 @@ function Protobuf.version() end --- Encodes an integer into a varint byte sequence. ---- @param value number|boolean|Int64HighLow The value to encode. Can be a number, boolean, or {high, low} pair for 64-bit values. +--- @param value integer|boolean|Int64HighLow The value to encode. Can be a number, boolean, or {high, low} pair for 64-bit values. --- @return string bytes The encoded varint byte sequence. function Protobuf.encode_varint(value) if type(value) == "boolean" then @@ -46,11 +46,10 @@ function Protobuf.encode_varint(value) end -- If value is a table, assume it's {high, low} format for 64-bit - if type(value) == "table" then + if bit64.is_int64(value) then --- @cast value Int64HighLow local bytes = {} - --- @type Int64HighLow - local v = { value[1], value[2] } -- Copy the input + local v = bit64.new(value[1], value[2]) -- Copy the input repeat -- Extract low 7 bits @@ -68,6 +67,7 @@ function Protobuf.encode_varint(value) return table.concat(bytes) end + --- @cast value -Int64HighLow -- For values that fit in 32 bits, use bit operations (fast path) if value >= 0 and value < 0x100000000 then @@ -86,8 +86,7 @@ function Protobuf.encode_varint(value) -- For large values (> 32 bits), convert to {high, low} and use bit64 local low_32 = value % 0x100000000 local high_32 = math.floor(value / 0x100000000) - --- @type Int64HighLow - local v = { high_32, low_32 } + local v = bit64.new(high_32, low_32) local bytes = {} repeat @@ -172,10 +171,12 @@ end --- @return string bytes The encoded 8-byte sequence. function Protobuf.encode_fixed64(value) local high, low - if type(value) == "table" then + if bit64.is_int64(value) then + --- @cast value Int64HighLow high, low = value[1], value[2] else - low = value % 0x100000000 + --- @cast value -Int64HighLow + low = math.floor(value % 0x100000000) high = math.floor(value / 0x100000000) end local b1 = low % 256 @@ -195,10 +196,11 @@ end --- @return Int64HighLow value The decoded 64-bit value as {high_32, low_32}. --- @return integer new_pos The new position in the buffer after decoding. function Protobuf.decode_fixed64(buffer, pos) + --- @type integer, integer, integer, integer, integer, integer, integer, integer local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(buffer, pos, pos + 7) local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216 local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216 - return { high, low }, pos + 8 + return bit64.new(high, low), pos + 8 end --- Encodes a floating-point number into a 4-byte IEEE 754 single-precision format. @@ -291,7 +293,7 @@ function Protobuf.encode_double(value) -- Mantissa is 52 bits, split across bytes local m = mantissa * 0x10000000000000 -- 2^52 - local m_low = m % 0x100000000 + local m_low = math.floor(m % 0x100000000) local m_high = math.floor(m / 0x100000000) % 0x100000 -- 20 bits local b1 = m_low % 256 @@ -371,13 +373,13 @@ end function Protobuf.zigzag_decode64(value) -- (n >>> 1) ^ -(n & 1) local shifted = bit64.shr(value, 1) - local sign_bit = { 0, bit32.band(value[2], 1) } - -- Negate: if sign_bit is 1, result is {0xFFFFFFFF, 0xFFFFFFFF}, else {0, 0} + local sign_bit = bit64.new(0, bit32.band(value[2], 1)) + -- Negate: if sign_bit is 1, result is all 1s, else all 0s local neg_sign if sign_bit[2] == 1 then - neg_sign = { 0xFFFFFFFF, 0xFFFFFFFF } + neg_sign = bit64.new(0xFFFFFFFF, 0xFFFFFFFF) else - neg_sign = { 0, 0 } + neg_sign = bit64.new(0, 0) end return bit64.bxor(shifted, neg_sign) end @@ -412,17 +414,18 @@ function Protobuf.encode(protoSchema, messageSchema, message) local values = message[field.name] if values ~= nil then if field.repeated then - if not IsList(values) or bit64.isInt64(values) then + if not is_list(values) or bit64.is_int64(values) then error("Field '" .. field.name .. "' is repeated but received a non-list value.") end else - if IsList(values) and not bit64.isInt64(values) then + if is_list(values) and not bit64.is_int64(values) then error("Field '" .. field.name .. "' is not repeated but received a list.") end values = { values } -- Wrap single value in a list for uniform processing end for _, value in ipairs(values) do -- Compute the key (field number and wire type) + --- @cast field.wireType integer local key = bit32.lshift(field_number, 3) + field.wireType buffer = buffer .. Protobuf.encode_varint(key) @@ -464,7 +467,7 @@ function Protobuf.encode(protoSchema, messageSchema, message) ) end -- For nested messages - local nested_message = Protobuf.encode(protoSchema, field.subschema, value) + local nested_message = Protobuf.encode(protoSchema, protoSchema.Message[field.subschema], value) buffer = buffer .. Protobuf.encode_length_delimited(nested_message) end else @@ -596,7 +599,7 @@ end --- Warning: Values exceeding 53-bit precision will lose precision. --- @param value Int64HighLow The {high_32, low_32} pair. --- @param strict? boolean If true, errors when value exceeds 53-bit precision. ---- @return number result The value as a Lua number (may lose precision for large values unless strict). +--- @return integer result The value as a Lua number (may lose precision for large values unless strict). function Protobuf.int64_to_number(value, strict) return bit64.to_number(value, strict) end @@ -632,7 +635,10 @@ function Protobuf.selftest() local passed = 0 local failed = 0 - --- Helper to convert string to hex for display + -- ============================================================================ + -- TEST HELPERS + -- ============================================================================ + local function to_hex(s) local hex = {} for i = 1, #s do @@ -641,20 +647,16 @@ function Protobuf.selftest() return table.concat(hex, " ") end - --- Assert helper local function assert_eq(actual, expected, msg) if actual == expected then passed = passed + 1 print(" PASS: " .. msg) - return true else failed = failed + 1 print(" FAIL: " .. msg .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) - return false end end - --- Assert helper for bytes local function assert_bytes(actual, expected_hex, msg) local expected = "" for byte in expected_hex:gmatch("%x%x") do @@ -663,315 +665,535 @@ function Protobuf.selftest() if actual == expected then passed = passed + 1 print(" PASS: " .. msg) - return true else failed = failed + 1 print(" FAIL: " .. msg .. ": expected " .. expected_hex .. ", got " .. to_hex(actual)) - return false end end - --- Assert helper for floats local function assert_close(actual, expected, epsilon, msg) if math.abs(actual - expected) <= epsilon then passed = passed + 1 print(" PASS: " .. msg) - return true else failed = failed + 1 print(" FAIL: " .. msg .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) - return false end end - --- Assert helper for {high, low} pairs - local function assert_int64_eq(actual, expected_high, expected_low, msg) - if type(actual) == "table" and actual[1] == expected_high and actual[2] == expected_low then + local function assert_int64(actual, high, low, msg) + if type(actual) == "table" and actual[1] == high and actual[2] == low then passed = passed + 1 print(" PASS: " .. msg) - return true else failed = failed + 1 - local actual_str = type(actual) == "table" and string.format("{%d, %d}", actual[1], actual[2]) or tostring(actual) - print(" FAIL: " .. msg .. ": expected {" .. expected_high .. ", " .. expected_low .. "}, got " .. actual_str) - return false + local actual_str = type(actual) == "table" and string.format("{0x%X, 0x%X}", actual[1], actual[2]) + or tostring(actual) + print(" FAIL: " .. msg .. string.format(": expected {0x%X, 0x%X}, got %s", high, low, actual_str)) end end - -- ========================================== - -- VARINT ENCODING TESTS (from official spec) - -- ========================================== - - -- Test: 0 -> 0x00 - assert_bytes(Protobuf.encode_varint(0), "00", "varint(0)") - - -- Test: 1 -> 0x01 - assert_bytes(Protobuf.encode_varint(1), "01", "varint(1)") - - -- Test: 127 -> 0x7F (max single byte) - assert_bytes(Protobuf.encode_varint(127), "7F", "varint(127)") - - -- Test: 128 -> 0x80 0x01 (first two-byte value) - assert_bytes(Protobuf.encode_varint(128), "8001", "varint(128)") + local function assert_error(fn, pattern, msg) + local ok, err = pcall(fn) + if not ok and type(err) == "string" and string.find(err, pattern) then + passed = passed + 1 + print(" PASS: " .. msg) + else + failed = failed + 1 + print(" FAIL: " .. msg .. (ok and " (no error thrown)" or ": " .. tostring(err))) + end + end - -- Test: 150 -> 0x96 0x01 (from official docs) - assert_bytes(Protobuf.encode_varint(150), "9601", "varint(150)") + -- ============================================================================ + -- TEST SCHEMA (shared across encode/decode tests) + -- ============================================================================ - -- Test: 300 -> 0xAC 0x02 - assert_bytes(Protobuf.encode_varint(300), "AC02", "varint(300)") + --- @type ProtoSchema + local Schema = { + WireType = { VARINT = 0, FIXED64 = 1, LENGTH_DELIMITED = 2, FIXED32 = 5 }, + DataType = { + DOUBLE = 1, + FLOAT = 2, + INT64 = 3, + UINT64 = 4, + INT32 = 5, + FIXED64 = 6, + FIXED32 = 7, + BOOL = 8, + STRING = 9, + MESSAGE = 11, + BYTES = 12, + UINT32 = 13, + ENUM = 14, + SFIXED32 = 15, + SFIXED64 = 16, + SINT32 = 17, + SINT64 = 18, + }, + Enum = {}, + Message = {}, + RPC = {}, + } - -- Test: 16383 -> 0xFF 0x7F (max two-byte) - assert_bytes(Protobuf.encode_varint(16383), "FF7F", "varint(16383)") + -- Helper to create simple single-field message schemas + local function make_schema(name, field_name, data_type, wire_type, opts) + opts = opts or {} + return { + name = name, + options = {}, + fields = { + [1] = { + name = field_name, + type = data_type, + wireType = wire_type, + repeated = opts.repeated, + subschema = opts.subschema, + }, + }, + } + end - -- Test: 16384 -> 0x80 0x80 0x01 (first three-byte) - assert_bytes(Protobuf.encode_varint(16384), "808001", "varint(16384)") + -- ============================================================================ + -- VERSION + -- ============================================================================ + + assert_eq(type(Protobuf.version()), "string", "version() returns string") + + -- ============================================================================ + -- VARINT ENCODING (official protobuf spec test vectors) + -- ============================================================================ + + local varint_vectors = { + { 0, "00" }, + { 1, "01" }, + { 127, "7F" }, + { 128, "8001" }, + { 150, "9601" }, -- from official docs + { 300, "AC02" }, + { 16383, "FF7F" }, -- max 2-byte + { 16384, "808001" }, -- first 3-byte + } + for _, t in ipairs(varint_vectors) do + assert_bytes(Protobuf.encode_varint(t[1]), t[2], "varint encode " .. t[1]) + end - -- Test roundtrip for various values - for _, v in ipairs({ 0, 1, 127, 128, 150, 300, 16383, 16384, 65535, 2097151, 268435455 }) do + -- Varint roundtrip + for _, v in ipairs({ 0, 1, 127, 128, 150, 300, 16383, 16384, 65535, 2097151, 268435455, 2 ^ 40 + 12345 }) do local enc = Protobuf.encode_varint(v) - local dec, _ = Protobuf.decode_varint(enc, 1) + local dec = Protobuf.decode_varint(enc, 1) assert_eq(dec, v, "varint roundtrip " .. v) end - -- ========================================== - -- VARINT64 TESTS - -- ========================================== - - -- Test: Large value that fits in 53 bits - local large_val = 2 ^ 40 + 12345 - local enc64 = Protobuf.encode_varint(large_val) - local dec64, _ = Protobuf.decode_varint(enc64, 1) - assert_eq(dec64, large_val, "varint64 roundtrip 2^40+12345") - - -- Test: decode_varint64 always returns {high, low} - local result64, _ = Protobuf.decode_varint64(Protobuf.encode_varint(150), 1) - assert_int64_eq(result64, 0, 150, "decode_varint64(150)") - - -- Test: {high, low} encoding - local hl_val = { 0x12345678, 0x9ABCDEF0 } - local hl_enc = Protobuf.encode_varint(hl_val) - local hl_dec, _ = Protobuf.decode_varint64(hl_enc, 1) - assert_int64_eq(hl_dec, hl_val[1], hl_val[2], "varint {high,low} roundtrip") - - -- ========================================== - -- FIXED32 TESTS - -- ========================================== - - -- Test: 0 -> 00 00 00 00 - assert_bytes(Protobuf.encode_fixed32(0), "00000000", "fixed32(0)") - - -- Test: 1 -> 01 00 00 00 (little-endian) - assert_bytes(Protobuf.encode_fixed32(1), "01000000", "fixed32(1)") - - -- Test: 12345 -> 39 30 00 00 (little-endian: 0x3039) - assert_bytes(Protobuf.encode_fixed32(12345), "39300000", "fixed32(12345)") - - -- Test: max uint32 -> FF FF FF FF - assert_bytes(Protobuf.encode_fixed32(0xFFFFFFFF), "FFFFFFFF", "fixed32(max)") - - -- Roundtrip - for _, v in ipairs({ 0, 1, 12345, 0xFFFFFFFF, 1234567890 }) do - local enc = Protobuf.encode_fixed32(v) - local dec, _ = Protobuf.decode_fixed32(enc, 1) - assert_eq(dec, v, "fixed32 roundtrip " .. v) + -- Varint64 with Int64 values + local v64 = bit64.new(0x12345678, 0x9ABCDEF0) + local enc64 = Protobuf.encode_varint(v64) + local dec64 = Protobuf.decode_varint64(enc64, 1) + assert_int64(dec64, v64[1], v64[2], "varint64 Int64 roundtrip") + assert_eq(bit64.is_int64(dec64), true, "decode_varint64 returns marked Int64") + + -- ============================================================================ + -- FIXED32 ENCODING + -- ============================================================================ + + local fixed32_vectors = { + { 0, "00000000" }, + { 1, "01000000" }, + { 12345, "39300000" }, + { 0xFFFFFFFF, "FFFFFFFF" }, + } + for _, t in ipairs(fixed32_vectors) do + assert_bytes(Protobuf.encode_fixed32(t[1]), t[2], "fixed32 encode " .. t[1]) + local dec = Protobuf.decode_fixed32(Protobuf.encode_fixed32(t[1]), 1) + assert_eq(dec, t[1], "fixed32 roundtrip " .. t[1]) end - -- ========================================== - -- FIXED64 TESTS - -- ========================================== + -- ============================================================================ + -- FIXED64 ENCODING + -- ============================================================================ - -- Test: 0 -> 00 00 00 00 00 00 00 00 - assert_bytes(Protobuf.encode_fixed64({ 0, 0 }), "0000000000000000", "fixed64(0)") - - -- Test: 1 -> 01 00 00 00 00 00 00 00 - assert_bytes(Protobuf.encode_fixed64({ 0, 1 }), "0100000000000000", "fixed64(1)") - - -- Roundtrip with {high, low} - local test_vals = { - { 0, 0 }, - { 0, 1 }, - { 0, 0xFFFFFFFF }, - { 1, 0 }, - { 0xFFFFFFFF, 0xFFFFFFFF }, - { 0x12345678, 0x9ABCDEF0 }, + local fixed64_vectors = { + { bit64.new(0, 0), "0000000000000000" }, + { bit64.new(0, 1), "0100000000000000" }, + { bit64.new(0, 0xFFFFFFFF), "FFFFFFFF00000000" }, + { bit64.new(1, 0), "0000000001000000" }, + { bit64.new(0xFFFFFFFF, 0xFFFFFFFF), "FFFFFFFFFFFFFFFF" }, } - for _, v in ipairs(test_vals) do - local enc = Protobuf.encode_fixed64(v) - local dec, _ = Protobuf.decode_fixed64(enc, 1) - assert_int64_eq(dec, v[1], v[2], string.format("fixed64 roundtrip {0x%X, 0x%X}", v[1], v[2])) + for _, t in ipairs(fixed64_vectors) do + assert_bytes(Protobuf.encode_fixed64(t[1]), t[2], string.format("fixed64 encode {0x%X, 0x%X}", t[1][1], t[1][2])) + local dec = Protobuf.decode_fixed64(Protobuf.encode_fixed64(t[1]), 1) + assert_int64(dec, t[1][1], t[1][2], string.format("fixed64 roundtrip {0x%X, 0x%X}", t[1][1], t[1][2])) end - -- ========================================== - -- FLOAT TESTS - -- ========================================== + -- ============================================================================ + -- FLOAT ENCODING + -- ============================================================================ - -- Test: 0.0 -> 00 00 00 00 - assert_bytes(Protobuf.encode_float(0), "00000000", "float(0)") + assert_bytes(Protobuf.encode_float(0), "00000000", "float encode 0") + assert_bytes(Protobuf.encode_float(1.0), "0000803F", "float encode 1.0") - -- Test: 1.0 -> 00 00 80 3F (IEEE 754: 0x3F800000) - assert_bytes(Protobuf.encode_float(1.0), "0000803F", "float(1.0)") - - -- Roundtrip - for _, v in ipairs({ 0.0, 1.0, -1.0, 3.14159, 100.5, -1234.5678 }) do - local enc = Protobuf.encode_float(v) - local dec, _ = Protobuf.decode_float(enc, 1) + for _, v in ipairs({ 0.0, 1.0, -1.0, 3.14159, 100.5, -1234.5678, 1e10, -1e-10 }) do + local dec = Protobuf.decode_float(Protobuf.encode_float(v), 1) assert_close(dec, v, 1e-4, "float roundtrip " .. v) end - -- ========================================== - -- DOUBLE TESTS - -- ========================================== - - -- Test: 0.0 -> 00 00 00 00 00 00 00 00 - assert_bytes(Protobuf.encode_double(0), "0000000000000000", "double(0)") + -- ============================================================================ + -- DOUBLE ENCODING + -- ============================================================================ - -- Test: 1.0 -> 00 00 00 00 00 00 F0 3F (IEEE 754: 0x3FF0000000000000) - assert_bytes(Protobuf.encode_double(1.0), "000000000000F03F", "double(1.0)") + assert_bytes(Protobuf.encode_double(0), "0000000000000000", "double encode 0") + assert_bytes(Protobuf.encode_double(1.0), "000000000000F03F", "double encode 1.0") - -- Roundtrip for _, v in ipairs({ 0.0, 1.0, -1.0, 3.141592653589793, 1e100, -1e-100 }) do - local enc = Protobuf.encode_double(v) - local dec, _ = Protobuf.decode_double(enc, 1) + local dec = Protobuf.decode_double(Protobuf.encode_double(v), 1) assert_close(dec, v, 1e-10, "double roundtrip " .. v) end - -- ========================================== - -- ZIGZAG TESTS (from official spec) - -- ========================================== + -- ============================================================================ + -- ZIGZAG ENCODING (official protobuf spec test vectors) + -- ============================================================================ - -- Official test vectors from protobuf spec - local zigzag32_tests = { + local zigzag32_vectors = { { 0, 0 }, { -1, 1 }, { 1, 2 }, { -2, 3 }, - { 2147483647, 4294967294 }, -- 0x7FFFFFFF -> 0xFFFFFFFE - { -2147483648, 4294967295 }, -- 0x80000000 -> 0xFFFFFFFF + { 2147483647, 4294967294 }, + { -2147483648, 4294967295 }, } - for _, test in ipairs(zigzag32_tests) do - local input, expected = test[1], test[2] - local actual = Protobuf.zigzag_encode32(input) - -- Handle signed result from bit.bxor - if actual < 0 then - actual = actual + 0x100000000 + for _, t in ipairs(zigzag32_vectors) do + local enc = Protobuf.zigzag_encode32(t[1]) + if enc < 0 then + enc = enc + 0x100000000 end - assert_eq(actual, expected, string.format("zigzag_encode32(%d)", input)) + assert_eq(enc, t[2], "zigzag32 encode " .. t[1]) end - -- Roundtrip zigzag32 for _, v in ipairs({ 0, 1, -1, 100, -100, 2147483647, -2147483648 }) do - local enc = Protobuf.zigzag_encode32(v) - local dec = Protobuf.zigzag_decode32(enc) + local dec = Protobuf.zigzag_decode32(Protobuf.zigzag_encode32(v)) assert_eq(dec, v, "zigzag32 roundtrip " .. v) end - -- Roundtrip zigzag64 - local zigzag64_tests = { - { { 0, 0 }, { 0, 0 } }, -- 0 -> 0 - { { 0xFFFFFFFF, 0xFFFFFFFF }, { 0, 1 } }, -- -1 -> 1 - { { 0, 1 }, { 0, 2 } }, -- 1 -> 2 + local zigzag64_vectors = { + { bit64.new(0, 0), bit64.new(0, 0) }, + { bit64.new(0xFFFFFFFF, 0xFFFFFFFF), bit64.new(0, 1) }, + { bit64.new(0, 1), bit64.new(0, 2) }, + } + for _, t in ipairs(zigzag64_vectors) do + local enc = Protobuf.zigzag_encode64(t[1]) + assert_int64(enc, t[2][1], t[2][2], string.format("zigzag64 encode {0x%X, 0x%X}", t[1][1], t[1][2])) + end + + local zigzag64_roundtrip = { + bit64.new(0, 0), + bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + bit64.new(0, 1), + bit64.new(0xFFFFFFFF, 0xFFFFFFFE), + bit64.new(0x7FFFFFFF, 0xFFFFFFFF), } - for _, test in ipairs(zigzag64_tests) do - local input, expected = test[1], test[2] - local actual = Protobuf.zigzag_encode64(input) - assert_int64_eq(actual, expected[1], expected[2], string.format("zigzag_encode64({%d,%d})", input[1], input[2])) + for _, v in ipairs(zigzag64_roundtrip) do + local dec = Protobuf.zigzag_decode64(Protobuf.zigzag_encode64(v)) + assert_int64(dec, v[1], v[2], string.format("zigzag64 roundtrip {0x%X, 0x%X}", v[1], v[2])) end - -- ========================================== - -- LENGTH-DELIMITED TESTS - -- ========================================== + -- ============================================================================ + -- LENGTH-DELIMITED ENCODING + -- ============================================================================ - -- Test: "testing" from official docs (12 07 74 65 73 74 69 6e 67) - local test_str = "testing" - local enc_str = Protobuf.encode_length_delimited(test_str) - assert_bytes(enc_str, "0774657374696E67", "length_delimited('testing')") + assert_bytes(Protobuf.encode_length_delimited("testing"), "0774657374696E67", "length_delimited 'testing'") - -- Roundtrip - local dec_str, _ = Protobuf.decode_length_delimited(enc_str, 1) - assert_eq(dec_str, test_str, "length_delimited roundtrip") + for _, s in ipairs({ "", "hello", string.rep("x", 1000) }) do + local dec = Protobuf.decode_length_delimited(Protobuf.encode_length_delimited(s), 1) + assert_eq(dec, s, "length_delimited roundtrip len=" .. #s) + end - -- ========================================== - -- INT64 UTILITY TESTS - -- ========================================== + -- ============================================================================ + -- INT64 UTILITIES + -- ============================================================================ - -- to_hex - local hex = Protobuf.int64_to_hex({ 0x12345678, 0x9ABCDEF0 }) - assert_eq(hex, "123456789ABCDEF0", "int64_to_hex") + assert_eq(Protobuf.int64_to_hex({ 0x12345678, 0x9ABCDEF0 }), "123456789ABCDEF0", "int64_to_hex") - -- to_number / from_number roundtrip local num = 123456789012345 - local hl = Protobuf.int64_from_number(num) - local back = Protobuf.int64_to_number(hl) - assert_eq(back, num, "int64_from_number/to_number roundtrip") + assert_eq(Protobuf.int64_to_number(Protobuf.int64_from_number(num)), num, "int64 from/to number roundtrip") + assert_eq(bit64.is_int64(Protobuf.int64_from_number(num)), true, "int64_from_number returns marked Int64") - -- equals assert_eq(Protobuf.int64_equals({ 1, 2 }, { 1, 2 }), true, "int64_equals same") assert_eq(Protobuf.int64_equals({ 1, 2 }, { 1, 3 }), false, "int64_equals diff") - - -- is_zero assert_eq(Protobuf.int64_is_zero({ 0, 0 }), true, "int64_is_zero true") assert_eq(Protobuf.int64_is_zero({ 0, 1 }), false, "int64_is_zero false") - -- ========================================== - -- INT64 METATABLE MARKER TESTS - -- ========================================== + assert_error(function() + Protobuf.int64_to_number(bit64.new(0x00200000, 0), true) + end, "53%-bit", "int64_to_number strict mode rejects >53-bit") - -- Test: decode_varint64 returns marked Int64 values - local decoded_int64, _ = Protobuf.decode_varint64(Protobuf.encode_varint(12345), 1) - assert_eq(bit64.isInt64(decoded_int64), true, "decode_varint64 returns marked Int64") - - -- Test: int64_from_number returns marked Int64 values - local from_num = Protobuf.int64_from_number(9876543210) - assert_eq(bit64.isInt64(from_num), true, "int64_from_number returns marked Int64") - - -- Test: IsList correctly distinguishes Int64 from arrays + -- Int64 vs array distinction local int64_val = bit64.new(0, 42) local array_val = { 1, 2 } - assert_eq(IsList(int64_val), true, "IsList sees Int64 as list-like (2 elements)") - assert_eq(bit64.isInt64(int64_val), true, "bit64.isInt64 identifies Int64") - assert_eq(bit64.isInt64(array_val), false, "bit64.isInt64 rejects plain array") + assert_eq(bit64.is_int64(int64_val), true, "bit64.is_int64 identifies Int64") + assert_eq(bit64.is_int64(array_val), false, "bit64.is_int64 rejects plain array") + + -- ============================================================================ + -- MESSAGE ENCODE/DECODE: SCALAR TYPES + -- ============================================================================ + + -- Bool + local boolSchema = make_schema("Bool", "flag", Schema.DataType.BOOL, Schema.WireType.VARINT) + for _, v in ipairs({ true, false }) do + local dec = Protobuf.decode(Schema, boolSchema, Protobuf.encode(Schema, boolSchema, { flag = v })) + assert_eq(dec.flag, v, "encode/decode bool " .. tostring(v)) + end - -- Test: Encode with uint64 field using Int64 value doesn't error - -- Create a minimal schema for testing - local testSchema = { - WireType = { VARINT = 0, FIXED64 = 1, LENGTH_DELIMITED = 2, FIXED32 = 5 }, - DataType = { UINT64 = 4 }, - Message = {}, + -- Int32 + local int32Schema = make_schema("Int32", "value", Schema.DataType.INT32, Schema.WireType.VARINT) + for _, v in ipairs({ 0, 1, 127, 128, 65535, 2147483647 }) do + local dec = Protobuf.decode(Schema, int32Schema, Protobuf.encode(Schema, int32Schema, { value = v })) + assert_eq(dec.value, v, "encode/decode int32 " .. v) + end + + -- Uint64 + local uint64Schema = make_schema("Uint64", "value", Schema.DataType.UINT64, Schema.WireType.VARINT) + local u64 = bit64.new(0x00001800, 0x00001000) + local decU64 = Protobuf.decode(Schema, uint64Schema, Protobuf.encode(Schema, uint64Schema, { value = u64 })) + assert_int64(decU64.value, u64[1], u64[2], "encode/decode uint64") + + -- Sint32 (zigzag) + local sint32Schema = make_schema("Sint32", "value", Schema.DataType.SINT32, Schema.WireType.VARINT) + for _, v in ipairs({ 0, 1, -1, 100, -100, 2147483647, -2147483648 }) do + local dec = Protobuf.decode(Schema, sint32Schema, Protobuf.encode(Schema, sint32Schema, { value = v })) + assert_eq(dec.value, v, "encode/decode sint32 " .. v) + end + + -- Sint64 (zigzag) + local sint64Schema = make_schema("Sint64", "value", Schema.DataType.SINT64, Schema.WireType.VARINT) + for _, t in ipairs({ { bit64.new(0, 0), "0" }, { bit64.new(0, 1), "1" }, { bit64.new(0xFFFFFFFF, 0xFFFFFFFF), "-1" } }) do + local dec = Protobuf.decode(Schema, sint64Schema, Protobuf.encode(Schema, sint64Schema, { value = t[1] })) + assert_int64(dec.value, t[1][1], t[1][2], "encode/decode sint64 " .. t[2]) + end + + -- String + local stringSchema = make_schema("String", "text", Schema.DataType.STRING, Schema.WireType.LENGTH_DELIMITED) + for _, v in ipairs({ "", "hello", "unicode: \xC3\xA9", string.rep("x", 1000) }) do + local dec = Protobuf.decode(Schema, stringSchema, Protobuf.encode(Schema, stringSchema, { text = v })) + assert_eq(dec.text, v, "encode/decode string len=" .. #v) + end + + -- Float + local floatSchema = make_schema("Float", "value", Schema.DataType.FLOAT, Schema.WireType.FIXED32) + for _, v in ipairs({ 0.0, 1.0, -1.0, 3.14159, 1e10 }) do + local dec = Protobuf.decode(Schema, floatSchema, Protobuf.encode(Schema, floatSchema, { value = v })) + assert_close(dec.value, v, 1e-4, "encode/decode float " .. v) + end + + -- Double + local doubleSchema = make_schema("Double", "value", Schema.DataType.DOUBLE, Schema.WireType.FIXED64) + for _, v in ipairs({ 0.0, 1.0, -1.0, 3.141592653589793, 1e100 }) do + local dec = Protobuf.decode(Schema, doubleSchema, Protobuf.encode(Schema, doubleSchema, { value = v })) + assert_close(dec.value, v, 1e-10, "encode/decode double " .. v) + end + + -- Fixed32 + local fixed32FieldSchema = make_schema("Fixed32", "value", Schema.DataType.FIXED32, Schema.WireType.FIXED32) + for _, v in ipairs({ 0, 1, 255, 0xFFFFFFFF }) do + local dec = Protobuf.decode(Schema, fixed32FieldSchema, Protobuf.encode(Schema, fixed32FieldSchema, { value = v })) + assert_eq(dec.value, v, "encode/decode fixed32 " .. v) + end + + -- Fixed64 + local fixed64FieldSchema = make_schema("Fixed64", "value", Schema.DataType.FIXED64, Schema.WireType.FIXED64) + for _, v in ipairs({ bit64.new(0, 0), bit64.new(0, 1), bit64.new(0xFFFFFFFF, 0xFFFFFFFF) }) do + local dec = Protobuf.decode(Schema, fixed64FieldSchema, Protobuf.encode(Schema, fixed64FieldSchema, { value = v })) + assert_int64(dec.value, v[1], v[2], string.format("encode/decode fixed64 {0x%X, 0x%X}", v[1], v[2])) + end + + -- Enum + local enumSchema = make_schema("Enum", "status", Schema.DataType.ENUM, Schema.WireType.VARINT) + for _, v in ipairs({ 0, 1, 2, 100 }) do + local dec = Protobuf.decode(Schema, enumSchema, Protobuf.encode(Schema, enumSchema, { status = v })) + assert_eq(dec.status, v, "encode/decode enum " .. v) + end + + -- ============================================================================ + -- MESSAGE ENCODE/DECODE: REPEATED FIELDS + -- ============================================================================ + + local repeatedSchema = + make_schema("Repeated", "values", Schema.DataType.INT32, Schema.WireType.VARINT, { repeated = true }) + local vals = { 1, 2, 3, 100, 200 } + local decR = Protobuf.decode(Schema, repeatedSchema, Protobuf.encode(Schema, repeatedSchema, { values = vals })) + assert_eq(#decR.values, #vals, "repeated field count") + for i, v in ipairs(vals) do + assert_eq(decR.values[i], v, "repeated field[" .. i .. "]") + end + + -- ============================================================================ + -- MESSAGE ENCODE/DECODE: NESTED MESSAGES + -- ============================================================================ + + -- Single level nesting + Schema.Message["Inner"] = { + name = "Inner", + options = {}, + fields = { + [1] = { name = "id", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + }, } - local testMessageSchema = { - name = "TestMessage", + local outerSchema = + make_schema("Outer", "inner", Schema.DataType.MESSAGE, Schema.WireType.LENGTH_DELIMITED, { subschema = "Inner" }) + local decO = Protobuf.decode(Schema, outerSchema, Protobuf.encode(Schema, outerSchema, { inner = { id = 42 } })) + assert_eq(decO.inner.id, 42, "nested message single level") + + -- Deep nesting (3 levels) + Schema.Message["L3"] = { + name = "L3", + options = {}, fields = { - [1] = { - name = "address", - type = testSchema.DataType.UINT64, - wireType = testSchema.WireType.VARINT, - repeated = false, + [1] = { name = "val", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + }, + } + Schema.Message["L2"] = { + name = "L2", + options = {}, + fields = { + [1] = { name = "name", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + [2] = { + name = "child", + type = Schema.DataType.MESSAGE, + wireType = Schema.WireType.LENGTH_DELIMITED, + subschema = "L3", }, }, } + local l1Schema = { + name = "L1", + options = {}, + fields = { + [1] = { name = "id", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + [2] = { + name = "child", + type = Schema.DataType.MESSAGE, + wireType = Schema.WireType.LENGTH_DELIMITED, + subschema = "L2", + }, + }, + } + local deepMsg = { id = 100, child = { name = "level2", child = { val = "deepest" } } } + local decD = Protobuf.decode(Schema, l1Schema, Protobuf.encode(Schema, l1Schema, deepMsg)) + assert_eq(decD.id, 100, "deep nested level1.id") + assert_eq(decD.child.name, "level2", "deep nested level2.name") + assert_eq(decD.child.child.val, "deepest", "deep nested level3.val") + + -- Repeated nested messages + Schema.Message["Item"] = { + name = "Item", + options = {}, + fields = { + [1] = { name = "name", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + [2] = { name = "qty", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + }, + } + local orderSchema = { + name = "Order", + options = {}, + fields = { + [1] = { name = "id", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + [2] = { + name = "items", + type = Schema.DataType.MESSAGE, + wireType = Schema.WireType.LENGTH_DELIMITED, + subschema = "Item", + repeated = true, + }, + }, + } + local orderMsg = { id = 123, items = { { name = "A", qty = 5 }, { name = "B", qty = 3 } } } + local decOrder = Protobuf.decode(Schema, orderSchema, Protobuf.encode(Schema, orderSchema, orderMsg)) + assert_eq(decOrder.id, 123, "repeated nested order.id") + assert_eq(#decOrder.items, 2, "repeated nested item count") + assert_eq(decOrder.items[1].name, "A", "repeated nested item[1].name") + assert_eq(decOrder.items[2].qty, 3, "repeated nested item[2].qty") + + -- Multiple nested message fields + Schema.Message["Addr"] = { + name = "Addr", + options = {}, + fields = { + [1] = { name = "street", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + [2] = { name = "city", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + }, + } + local personSchema = { + name = "Person", + options = {}, + fields = { + [1] = { name = "name", type = Schema.DataType.STRING, wireType = Schema.WireType.LENGTH_DELIMITED }, + [2] = { + name = "home", + type = Schema.DataType.MESSAGE, + wireType = Schema.WireType.LENGTH_DELIMITED, + subschema = "Addr", + }, + [3] = { + name = "work", + type = Schema.DataType.MESSAGE, + wireType = Schema.WireType.LENGTH_DELIMITED, + subschema = "Addr", + }, + }, + } + local personMsg = + { name = "Alice", home = { street = "123 Home", city = "H" }, work = { street = "456 Work", city = "W" } } + local decP = Protobuf.decode(Schema, personSchema, Protobuf.encode(Schema, personSchema, personMsg)) + assert_eq(decP.name, "Alice", "multi-nested person.name") + assert_eq(decP.home.city, "H", "multi-nested home.city") + assert_eq(decP.work.city, "W", "multi-nested work.city") + + -- ============================================================================ + -- EDGE CASES + -- ============================================================================ + + -- Empty message + local emptyDec = Protobuf.decode(Schema, int32Schema, Protobuf.encode(Schema, int32Schema, {})) + assert_eq(emptyDec.value, nil, "empty message has nil field") + + -- Unknown field skipping + local twoField = { + name = "Two", + options = {}, + fields = { + [1] = { name = "a", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + [2] = { name = "b", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + }, + } + local oneField = { + name = "One", + options = {}, + fields = { + [1] = { name = "a", type = Schema.DataType.INT32, wireType = Schema.WireType.VARINT }, + }, + } + local enc2 = Protobuf.encode(Schema, twoField, { a = 123, b = 456 }) + local dec1 = Protobuf.decode(Schema, oneField, enc2) + assert_eq(dec1.a, 123, "unknown field skipped, known preserved") - -- This should NOT error - Int64 values should not be treated as repeated fields - local success, err = pcall(function() - local int64_address = bit64.new(0, 0x12345678) - Protobuf.encode(testSchema, testMessageSchema, { address = int64_address }) - end) - if success then - passed = passed + 1 - print(" PASS: encode with Int64 uint64 field") - else - failed = failed + 1 - print(" FAIL: encode with Int64 uint64 field: " .. tostring(err)) - end + -- ============================================================================ + -- ERROR HANDLING + -- ============================================================================ + + assert_error(function() + Protobuf.encode(Schema, repeatedSchema, { values = 123 }) + end, "non%-list", "error: non-list for repeated field") + + assert_error(function() + Protobuf.encode(Schema, int32Schema, { value = { 1, 2, 3 } }) + end, "not repeated", "error: list for non-repeated field") - -- Test: Encode/decode roundtrip with uint64 field - local original_address = bit64.new(0x00001800, 0x00001000) - local encoded = Protobuf.encode(testSchema, testMessageSchema, { address = original_address }) - local decoded, _ = Protobuf.decode(testSchema, testMessageSchema, encoded) - assert_int64_eq(decoded.address, original_address[1], original_address[2], "uint64 field encode/decode roundtrip") + local noSubSchema = make_schema("NoSub", "nested", Schema.DataType.MESSAGE, Schema.WireType.LENGTH_DELIMITED) + assert_error(function() + Protobuf.encode(Schema, noSubSchema, { nested = { foo = 1 } }) + end, "no subschema", "error: nested message without subschema") - -- ========================================== + -- ============================================================================ -- SUMMARY - -- ========================================== + -- ============================================================================ print(string.format("\nProtobuf operations: %d/%d tests passed\n", passed, passed + failed)) return failed == 0 end diff --git a/src/protobuf/types.lua b/src/protobuf/types.lua new file mode 100644 index 0000000..6643e1a --- /dev/null +++ b/src/protobuf/types.lua @@ -0,0 +1,70 @@ +-- Generated Lua schema from protobuf descriptor set +-- Do not edit manually + +--- @class ProtoSchema +local ProtoSchema = {} + +--- Maps enum names to their definitions. +ProtoSchema.Enum = {} + +--- Maps message names to their definitions. +--- @type table +ProtoSchema.Message = {} + +--- Maps service names to their method definitions. +--- @type table +ProtoSchema.RPC = {} + +--- ProtoWireType Maps protobuf wire types to their integer values. +--- @enum ProtoWireType +ProtoSchema.WireType = { + VARINT = 0, + FIXED64 = 1, + LENGTH_DELIMITED = 2, + FIXED32 = 5, +} + +--- ProtoDataType Maps protobuf data types to their integer values. +--- @enum ProtoDataType +ProtoSchema.DataType = { + DOUBLE = 1, + FLOAT = 2, + INT64 = 3, + UINT64 = 4, + INT32 = 5, + FIXED64 = 6, + FIXED32 = 7, + BOOL = 8, + STRING = 9, + MESSAGE = 11, + BYTES = 12, + UINT32 = 13, + ENUM = 14, + SFIXED32 = 15, + SFIXED64 = 16, + SINT32 = 17, + SINT64 = 18, +} + +--- @class ProtoFieldSchema +--- @field name string The name of the field. +--- @field wireType ProtoWireType The protobuf wire type (see ProtoSchema.WireType). +--- @field type ProtoDataType The protobuf type (see ProtoSchema.DataType). +--- @field repeated boolean? Whether the field is repeated (optional). +--- @field subschema string? The subschema name for nested messages (optional). + +--- @class ProtoMessageSchema +--- @field name string The name of the message type. +--- @field options table Message options. +--- @field fields table A map of field numbers to ProtoFieldSchema definitions. + +--- @class ProtoServiceMethodSchema +--- @field service string The name of the service. +--- @field method string The method name. +--- @field inputType ProtoMessageSchema The protobuf message type for the request. +--- @field outputType ProtoMessageSchema The protobuf message type for the response. + +--- @class ProtoServiceSchema +--- @field [string] ProtoServiceMethodSchema Maps method names to their method definitions. + +return ProtoSchema diff --git a/tools/gen_lua_proto_schema b/tools/gen_lua_proto_schema index cb0d835..90b1acf 100755 --- a/tools/gen_lua_proto_schema +++ b/tools/gen_lua_proto_schema @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse import os +import shutil import subprocess import tempfile from textwrap import dedent @@ -40,53 +41,53 @@ LUA_TYPE_MAP: dict[int, list[str]] = { # Lua wire type mapping WIRE_TYPE_MAP: dict[int, str] = { - FieldDescriptorProto.TYPE_DOUBLE: "PROTOBUF_SCHEMA.WireType.FIXED64", - FieldDescriptorProto.TYPE_FLOAT: "PROTOBUF_SCHEMA.WireType.FIXED32", - FieldDescriptorProto.TYPE_INT64: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_UINT64: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_INT32: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_FIXED64: "PROTOBUF_SCHEMA.WireType.FIXED64", - FieldDescriptorProto.TYPE_FIXED32: "PROTOBUF_SCHEMA.WireType.FIXED32", - FieldDescriptorProto.TYPE_BOOL: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_STRING: "PROTOBUF_SCHEMA.WireType.LENGTH_DELIMITED", - FieldDescriptorProto.TYPE_MESSAGE: "PROTOBUF_SCHEMA.WireType.LENGTH_DELIMITED", - FieldDescriptorProto.TYPE_BYTES: "PROTOBUF_SCHEMA.WireType.LENGTH_DELIMITED", - FieldDescriptorProto.TYPE_UINT32: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_ENUM: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_SFIXED32: "PROTOBUF_SCHEMA.WireType.FIXED32", - FieldDescriptorProto.TYPE_SFIXED64: "PROTOBUF_SCHEMA.WireType.FIXED64", - FieldDescriptorProto.TYPE_SINT32: "PROTOBUF_SCHEMA.WireType.VARINT", - FieldDescriptorProto.TYPE_SINT64: "PROTOBUF_SCHEMA.WireType.VARINT", + FieldDescriptorProto.TYPE_DOUBLE: "ProtoSchema.WireType.FIXED64", + FieldDescriptorProto.TYPE_FLOAT: "ProtoSchema.WireType.FIXED32", + FieldDescriptorProto.TYPE_INT64: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_UINT64: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_INT32: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_FIXED64: "ProtoSchema.WireType.FIXED64", + FieldDescriptorProto.TYPE_FIXED32: "ProtoSchema.WireType.FIXED32", + FieldDescriptorProto.TYPE_BOOL: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_STRING: "ProtoSchema.WireType.LENGTH_DELIMITED", + FieldDescriptorProto.TYPE_MESSAGE: "ProtoSchema.WireType.LENGTH_DELIMITED", + FieldDescriptorProto.TYPE_BYTES: "ProtoSchema.WireType.LENGTH_DELIMITED", + FieldDescriptorProto.TYPE_UINT32: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_ENUM: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_SFIXED32: "ProtoSchema.WireType.FIXED32", + FieldDescriptorProto.TYPE_SFIXED64: "ProtoSchema.WireType.FIXED64", + FieldDescriptorProto.TYPE_SINT32: "ProtoSchema.WireType.VARINT", + FieldDescriptorProto.TYPE_SINT64: "ProtoSchema.WireType.VARINT", } TYPE_MAP: dict[int, str] = { - FieldDescriptorProto.TYPE_DOUBLE: "PROTOBUF_SCHEMA.DataType.DOUBLE", - FieldDescriptorProto.TYPE_FLOAT: "PROTOBUF_SCHEMA.DataType.FLOAT", - FieldDescriptorProto.TYPE_INT64: "PROTOBUF_SCHEMA.DataType.INT64", - FieldDescriptorProto.TYPE_UINT64: "PROTOBUF_SCHEMA.DataType.UINT64", - FieldDescriptorProto.TYPE_INT32: "PROTOBUF_SCHEMA.DataType.INT32", - FieldDescriptorProto.TYPE_FIXED64: "PROTOBUF_SCHEMA.DataType.FIXED64", - FieldDescriptorProto.TYPE_FIXED32: "PROTOBUF_SCHEMA.DataType.FIXED32", - FieldDescriptorProto.TYPE_BOOL: "PROTOBUF_SCHEMA.DataType.BOOL", - FieldDescriptorProto.TYPE_STRING: "PROTOBUF_SCHEMA.DataType.STRING", - FieldDescriptorProto.TYPE_MESSAGE: "PROTOBUF_SCHEMA.DataType.MESSAGE", - FieldDescriptorProto.TYPE_BYTES: "PROTOBUF_SCHEMA.DataType.BYTES", - FieldDescriptorProto.TYPE_UINT32: "PROTOBUF_SCHEMA.DataType.UINT32", - FieldDescriptorProto.TYPE_ENUM: "PROTOBUF_SCHEMA.DataType.ENUM", - FieldDescriptorProto.TYPE_SFIXED32: "PROTOBUF_SCHEMA.DataType.SFIXED32", - FieldDescriptorProto.TYPE_SFIXED64: "PROTOBUF_SCHEMA.DataType.SFIXED64", - FieldDescriptorProto.TYPE_SINT32: "PROTOBUF_SCHEMA.DataType.SINT32", - FieldDescriptorProto.TYPE_SINT64: "PROTOBUF_SCHEMA.DataType.SINT64", + FieldDescriptorProto.TYPE_DOUBLE: "ProtoSchema.DataType.DOUBLE", + FieldDescriptorProto.TYPE_FLOAT: "ProtoSchema.DataType.FLOAT", + FieldDescriptorProto.TYPE_INT64: "ProtoSchema.DataType.INT64", + FieldDescriptorProto.TYPE_UINT64: "ProtoSchema.DataType.UINT64", + FieldDescriptorProto.TYPE_INT32: "ProtoSchema.DataType.INT32", + FieldDescriptorProto.TYPE_FIXED64: "ProtoSchema.DataType.FIXED64", + FieldDescriptorProto.TYPE_FIXED32: "ProtoSchema.DataType.FIXED32", + FieldDescriptorProto.TYPE_BOOL: "ProtoSchema.DataType.BOOL", + FieldDescriptorProto.TYPE_STRING: "ProtoSchema.DataType.STRING", + FieldDescriptorProto.TYPE_MESSAGE: "ProtoSchema.DataType.MESSAGE", + FieldDescriptorProto.TYPE_BYTES: "ProtoSchema.DataType.BYTES", + FieldDescriptorProto.TYPE_UINT32: "ProtoSchema.DataType.UINT32", + FieldDescriptorProto.TYPE_ENUM: "ProtoSchema.DataType.ENUM", + FieldDescriptorProto.TYPE_SFIXED32: "ProtoSchema.DataType.SFIXED32", + FieldDescriptorProto.TYPE_SFIXED64: "ProtoSchema.DataType.SFIXED64", + FieldDescriptorProto.TYPE_SINT32: "ProtoSchema.DataType.SINT32", + FieldDescriptorProto.TYPE_SINT64: "ProtoSchema.DataType.SINT64", } def get_wire_type(field_type: int) -> str: """Maps a protobuf field type to Lua's wire type representation.""" - return WIRE_TYPE_MAP.get(field_type, "PROTOBUF_SCHEMA.WireType.UNKNOWN") + return WIRE_TYPE_MAP.get(field_type, "ProtoSchema.WireType.UNKNOWN") def get_type(field_type: int) -> str: """Maps a protobuf field type to Lua's wire type representation.""" - return TYPE_MAP.get(field_type, "PROTOBUF_SCHEMA.WIRE.UNKNOWN") + return TYPE_MAP.get(field_type, "ProtoSchema.WIRE.UNKNOWN") def get_lua_type(field: FieldDescriptorProto) -> str: @@ -187,7 +188,8 @@ def parse_message( indent: int = 0, ) -> str: """Parse a message descriptor into a Lua table.""" - lua = (" " * indent) + f"{message.name} = {{\n" + lua = (" " * indent) + f"--- @type ProtoMessageSchema\n" + lua += (" " * indent) + f"ProtoSchema.Message.{message.name} = {{\n" lua += (" " * (indent + 1)) + f'name = "{message.name}",\n' # Build options table @@ -233,13 +235,13 @@ def parse_message( else: lua += "},\n" - return lua + (" " * indent) + "},\n" + return lua + (" " * indent) + "}\n\n" def parse_enum(enum, indent: int = 0) -> str: """Parse an enum descriptor into a Lua table.""" lua = (" " * indent) + f"--- @enum Proto{enum.name}\n" - lua += (" " * indent) + f"{enum.name} = {{" + lua += (" " * indent) + f"ProtoSchema.Enum.{enum.name} = {{" new_line_added = False for value in enum.value: if not new_line_added: @@ -247,9 +249,9 @@ def parse_enum(enum, indent: int = 0) -> str: new_line_added = True lua += parse_enum_value(value, indent=indent + 1) if new_line_added: - lua += (" " * indent) + "},\n" + lua += (" " * indent) + "}\n\n" else: - lua += "},\n" + lua += "}\n\n" return lua @@ -262,17 +264,18 @@ def parse_service_method( lua += (" " * (indent + 1)) + f'method = "{method.name}",\n' lua += ( " " * (indent + 1) - ) + f"inputType = PROTOBUF_SCHEMA.Message{method.input_type},\n" + ) + f"inputType = ProtoSchema.Message{method.input_type},\n" lua += ( " " * (indent + 1) - ) + f"outputType = PROTOBUF_SCHEMA.Message{method.output_type},\n" + ) + f"outputType = ProtoSchema.Message{method.output_type},\n" lua += (" " * indent) + "},\n" return lua def parse_service(service, indent: int = 0) -> str: """Parse a service descriptor into a Lua table.""" - lua = (" " * indent) + f"{service.name} = {{" + lua = (" " * indent) + f"--- @type ProtoServiceSchema\n" + lua += (" " * indent) + f"ProtoSchema.RPC.{service.name} = {{" new_line_added = False for method in service.method: if not new_line_added: @@ -280,9 +283,9 @@ def parse_service(service, indent: int = 0) -> str: new_line_added = True lua += parse_service_method(service, method, indent=indent + 1) if new_line_added: - lua += (" " * indent) + "},\n" + lua += (" " * indent) + "}\n\n" else: - lua += "},\n" + lua += "}\n\n" return lua @@ -304,7 +307,6 @@ def parse_descriptor_set(fds: FileDescriptorSet) -> str: lua = "" # Generate class annotations for all message types first - lua += "--\n-- Message type annotations for IDE support\n--\n\n" for proto_file in fds.file: if proto_file.package == "google.protobuf": continue @@ -314,51 +316,66 @@ def parse_descriptor_set(fds: FileDescriptorSet) -> str: lua += generate_message_class_annotation(message, indent=0) lua += "\n" - lua += "PROTOBUF_SCHEMA.Enum = {\n" enums = [] + any_content_added = False for proto_file in fds.file: if proto_file.package == "google.protobuf": continue - if proto_file.package: - lua += f"-- Package: {proto_file.package}\n" + comment_added = False for message in proto_file.enum_type: - lua += parse_enum(message, indent=1) + any_content_added = True + if not comment_added and proto_file.package: + lua += f"-- Package: {proto_file.package}\n" + comment_added = True + lua += parse_enum(message, indent=0) enums.append(f"Proto{message.name}") - lua += "}\n\n" - lua += f"--- @alias ProtoEnum " + "|".join(enums) + "\n\n" + if any_content_added: + lua += "\n" - lua += "PROTOBUF_SCHEMA.Message = {\n" + any_content_added = False for proto_file in fds.file: if proto_file.package == "google.protobuf": continue - if proto_file.package: - lua += f"-- Package: {proto_file.package}\n" + comment_added = False for message in proto_file.message_type: + any_content_added = True + if not comment_added and proto_file.package: + lua += f"-- Package: {proto_file.package}\n" + comment_added = True lua += parse_message( - message, message_option_extensions=message_option_extensions, indent=1 + message, message_option_extensions=message_option_extensions, indent=0 ) - lua += "}\n\n" + if any_content_added: + lua += "\n" - lua += "PROTOBUF_SCHEMA.RPC = {\n" + any_content_added = False for proto_file in fds.file: if proto_file.package == "google.protobuf": continue - if proto_file.package: - lua += f"-- Package: {proto_file.package}\n" + comment_added = False for service in proto_file.service: - lua += parse_service(service, indent=1) - lua += "}\n" + any_content_added = True + if not comment_added and proto_file.package: + lua += f"-- Package: {proto_file.package}\n" + comment_added = True + lua += parse_service(service, indent=0) + if any_content_added: + lua += "\n" return lua def resolve_input(input_file_or_url: str, tmpdir: str) -> str: - """Resolve an input file or URL to a local file path.""" + """Resolve an input file or URL to a local file path in tmpdir.""" if not os.path.isdir(tmpdir): raise ValueError("Temporary directory must be provided for URL inputs.") if os.path.isfile(input_file_or_url): - return input_file_or_url + # Copy local file to tmpdir so protoc can find it + filename = os.path.basename(input_file_or_url) + dest_path = os.path.join(tmpdir, filename) + shutil.copy(input_file_or_url, dest_path) + return dest_path parsed_url = urlparse(input_file_or_url) if parsed_url.scheme not in ("http", "https"): @@ -411,49 +428,32 @@ def main(): -- Generated Lua schema from protobuf descriptor set -- Do not edit manually - --- Int64 values may be represented as a {{high, low}} table when the value - --- exceeds Lua's number precision (values > 2^53). - --- @alias Int64HighLow {{[1]: number, [2]: number}} - - --- @class ProtoFieldSchema - --- @field name string The name of the field. - --- @field wireType ProtoWireType The protobuf wire type (see ProtoSchema.WireType). - --- @field type ProtoDataType The protobuf type (see ProtoSchema.DataType). - --- @field repeated boolean? Whether the field is repeated (optional). - --- @field subschema string? The subschema name for nested messages (optional). - - --- @class ProtoMessageSchema - --- @field name string The name of the message type. - --- @field options table Message options. - --- @field fields table A map of field numbers to ProtoFieldSchema definitions. - - --- @class ProtoServiceMethodSchema - --- @field service string The name of the service. - --- @field method string The method name. - --- @field inputType ProtoMessageSchema The protobuf message type for the request. - --- @field outputType ProtoMessageSchema The protobuf message type for the response. - - --- @class ProtoServiceSchema - --- @field [string] ProtoServiceMethodSchema Maps method names to their method definitions. - --- @class ProtoSchema - --- @field WireType ProtoWireType Maps protobuf wire types to their integer values. - --- @field DataType ProtoDataType Maps protobuf data types to their integer values. - --- @field Message table Maps message names to their definitions. - --- @field Enum table Maps enum names to their definitions. - --- @field RPC table Maps service names to their method definitions. - local PROTOBUF_SCHEMA = {{}} + local ProtoSchema = {{}} + --- Maps enum names to their definitions. + ProtoSchema.Enum = {{}} + + --- Maps message names to their definitions. + --- @type table + ProtoSchema.Message = {{}} + + --- Maps service names to their method definitions. + --- @type table + ProtoSchema.RPC = {{}} + + --- ProtoWireType Maps protobuf wire types to their integer values. --- @enum ProtoWireType - PROTOBUF_SCHEMA.WireType = {{ + ProtoSchema.WireType = {{ VARINT = 0, FIXED64 = 1, LENGTH_DELIMITED = 2, FIXED32 = 5, }} + --- ProtoDataType Maps protobuf data types to their integer values. --- @enum ProtoDataType - PROTOBUF_SCHEMA.DataType = {{ + ProtoSchema.DataType = {{ DOUBLE = 1, FLOAT = 2, INT64 = 3, @@ -472,12 +472,33 @@ def main(): SINT32 = 17, SINT64 = 18, }} + + --- @class ProtoFieldSchema + --- @field name string The name of the field. + --- @field wireType ProtoWireType The protobuf wire type (see ProtoSchema.WireType). + --- @field type ProtoDataType The protobuf type (see ProtoSchema.DataType). + --- @field repeated boolean? Whether the field is repeated (optional). + --- @field subschema string? The subschema name for nested messages (optional). + + --- @class ProtoMessageSchema + --- @field name string The name of the message type. + --- @field options table Message options. + --- @field fields table A map of field numbers to ProtoFieldSchema definitions. + + --- @class ProtoServiceMethodSchema + --- @field service string The name of the service. + --- @field method string The method name. + --- @field inputType ProtoMessageSchema The protobuf message type for the request. + --- @field outputType ProtoMessageSchema The protobuf message type for the response. + + --- @class ProtoServiceSchema + --- @field [string] ProtoServiceMethodSchema Maps method names to their method definitions. """ ) ) f.write(lua_content) - f.write("\nreturn PROTOBUF_SCHEMA\n") + f.write("return ProtoSchema\n") print(f"Schema successfully generated and saved to {args.output}") diff --git a/vendor/bitn.lua b/vendor/bitn.lua index f79e11d..b434f48 100644 --- a/vendor/bitn.lua +++ b/vendor/bitn.lua @@ -1245,9 +1245,15 @@ end --- Warning: Lua numbers use 64-bit IEEE 754 doubles with 53-bit mantissa precision. --- Values exceeding 53 bits (greater than 9007199254740991) will lose precision. --- To maintain full 64-bit precision, keep values in {high, low} format. ---- @param value number|Int64HighLow The {high_32, low_32} pair (or number to pass through). +--- @param value number|integer|Int64HighLow The {high_32, low_32} pair (or number to pass through). --- @param strict? boolean If true, errors when value exceeds 53-bit precision. ---- @return number result The value as a Lua number (may lose precision for large values unless strict). +--- @return number|integer result The value as a Lua number (may lose precision for large values unless strict). +--- @overload fun(value: number, strict?: boolean): number +--- @overload fun(value: integer, strict?: boolean): integer +--- @overload fun(value: Int64HighLow, strict?: boolean): integer +--- @overload fun(value: number): number +--- @overload fun(value: integer): integer +--- @overload fun(value: Int64HighLow): integer function bit64.to_number(value, strict) if type(value) == "number" then return value @@ -1910,13 +1916,16 @@ end --- --- @class bitn local bitn = { + --- @type bit16 bit16 = require("bitn.bit16"), + --- @type bit32 bit32 = require("bitn.bit32"), + --- @type bit64 bit64 = require("bitn.bit64"), } --- Library version (injected at build time for releases). -local VERSION = "v0.4.1" +local VERSION = "v0.4.2" --- Get the library version string. --- @return string version Version string (e.g., "v1.0.0" or "dev")