commit 0c9bec0d38ed3d2c45d7be4e764a0bcffef98be1
parent c31d6774ac7db4cfbc548ce507ae65ab6036f873
Author: Roberto Ierusalimschy <roberto@inf.puc-rio.br>
Date: Wed, 7 Feb 2024 13:39:27 -0300
Better handling of size limit when resizing a table
Avoid silent conversions from int to unsigned int when calling
'luaH_resize'; avoid silent conversions from lua_Integer to int in
'table.create'; MAXASIZE corrected for the new implementation of arrays;
'luaH_resize' checks explicitly whether new size respects MAXASIZE.
(Even constructors were bypassing that check.)
Diffstat:
6 files changed, 37 insertions(+), 21 deletions(-)
diff --git a/lapi.c b/lapi.c
@@ -781,7 +781,7 @@ LUA_API int lua_rawgetp (lua_State *L, int idx, const void *p) {
}
-LUA_API void lua_createtable (lua_State *L, int narray, int nrec) {
+LUA_API void lua_createtable (lua_State *L, unsigned narray, unsigned nrec) {
Table *t;
lua_lock(L);
t = luaH_new(L);
diff --git a/ltable.c b/ltable.c
@@ -61,18 +61,25 @@ typedef union {
/*
-** MAXABITS is the largest integer such that MAXASIZE fits in an
+** MAXABITS is the largest integer such that 2^MAXABITS fits in an
** unsigned int.
*/
#define MAXABITS cast_int(sizeof(int) * CHAR_BIT - 1)
/*
+** MAXASIZEB is the maximum number of elements in the array part such
+** that the size of the array fits in 'size_t'.
+*/
+#define MAXASIZEB ((MAX_SIZET/sizeof(ArrayCell)) * NM)
+
+
+/*
** MAXASIZE is the maximum size of the array part. It is the minimum
-** between 2^MAXABITS and the maximum size that, measured in bytes,
-** fits in a 'size_t'.
+** between 2^MAXABITS and MAXASIZEB.
*/
-#define MAXASIZE luaM_limitN(1u << MAXABITS, TValue)
+#define MAXASIZE \
+ (((1u << MAXABITS) < MAXASIZEB) ? (1u << MAXABITS) : cast_uint(MAXASIZEB))
/*
** MAXHBITS is the largest integer such that 2^MAXHBITS fits in a
@@ -663,6 +670,8 @@ void luaH_resize (lua_State *L, Table *t, unsigned int newasize,
Table newt; /* to keep the new hash part */
unsigned int oldasize = setlimittosize(t);
ArrayCell *newarray;
+ if (newasize > MAXASIZE)
+ luaG_runerror(L, "table overflow");
/* create new hash part with appropriate size into 'newt' */
newt.flags = 0;
setnodevector(L, &newt, nhsize);
diff --git a/ltablib.c b/ltablib.c
@@ -59,8 +59,10 @@ static void checktab (lua_State *L, int arg, int what) {
static int tcreate (lua_State *L) {
- int sizeseq = (int)luaL_checkinteger(L, 1);
- int sizerest = (int)luaL_optinteger(L, 2, 0);
+ lua_Unsigned sizeseq = (lua_Unsigned)luaL_checkinteger(L, 1);
+ lua_Unsigned sizerest = (lua_Unsigned)luaL_optinteger(L, 2, 0);
+ luaL_argcheck(L, sizeseq <= UINT_MAX, 1, "out of range");
+ luaL_argcheck(L, sizerest <= UINT_MAX, 2, "out of range");
lua_createtable(L, sizeseq, sizerest);
return 1;
}
diff --git a/lua.h b/lua.h
@@ -268,7 +268,7 @@ LUA_API int (lua_rawget) (lua_State *L, int idx);
LUA_API int (lua_rawgeti) (lua_State *L, int idx, lua_Integer n);
LUA_API int (lua_rawgetp) (lua_State *L, int idx, const void *p);
-LUA_API void (lua_createtable) (lua_State *L, int narr, int nrec);
+LUA_API void (lua_createtable) (lua_State *L, unsigned narr, unsigned nrec);
LUA_API void *(lua_newuserdatauv) (lua_State *L, size_t sz, int nuvalue);
LUA_API int (lua_getmetatable) (lua_State *L, int objindex);
LUA_API int (lua_getiuservalue) (lua_State *L, int idx, int n);
diff --git a/manual/manual.of b/manual/manual.of
@@ -3234,7 +3234,7 @@ Values at other positions are not affected.
}
-@APIEntry{void lua_createtable (lua_State *L, int nseq, int nrec);|
+@APIEntry{void lua_createtable (lua_State *L, unsigned nseq, unsigned nrec);|
@apii{0,1,m}
Creates a new empty table and pushes it onto the stack.
diff --git a/testes/sort.lua b/testes/sort.lua
@@ -3,19 +3,30 @@
print "testing (parts of) table library"
+local maxI = math.maxinteger
+local minI = math.mininteger
+
+
+local function checkerror (msg, f, ...)
+ local s, err = pcall(f, ...)
+ assert(not s and string.find(err, msg))
+end
+
+
do print "testing 'table.create'"
+ local N = 10000
collectgarbage()
local m = collectgarbage("count") * 1024
- local t = table.create(10000)
+ local t = table.create(N)
local memdiff = collectgarbage("count") * 1024 - m
- assert(memdiff > 10000 * 4)
+ assert(memdiff > N * 4)
for i = 1, 20 do
assert(#t == i - 1)
t[i] = 0
end
for i = 1, 20 do t[#t + 1] = i * 10 end
assert(#t == 40 and t[39] == 190)
- assert(not T or T.querytab(t) == 10000)
+ assert(not T or T.querytab(t) == N)
t = nil
collectgarbage()
m = collectgarbage("count") * 1024
@@ -23,6 +34,9 @@ do print "testing 'table.create'"
memdiff = collectgarbage("count") * 1024 - m
assert(memdiff > 1024 * 12)
assert(not T or select(2, T.querytab(t)) == 1024)
+
+ checkerror("table overflow", table.create, (1<<31) + 1)
+ checkerror("table overflow", table.create, 0, (1<<31) + 1)
end
@@ -30,15 +44,6 @@ print "testing unpack"
local unpack = table.unpack
-local maxI = math.maxinteger
-local minI = math.mininteger
-
-
-local function checkerror (msg, f, ...)
- local s, err = pcall(f, ...)
- assert(not s and string.find(err, msg))
-end
-
checkerror("wrong number of arguments", table.insert, {}, 2, 3, 4)