X-Git-Url: https://ruderich.org/simon/gitweb/?a=blobdiff_plain;f=src%2Ftrackfds.h;h=30693f47d1a07888bb578c665227f2ae3afc36d6;hb=0d7f3068981f2b08e583cec21d9069e97c73addd;hp=b9089bbd25b0444772c6442ca27dedd3ac7e4334;hpb=6d6ad4423ae87771bd44f90006a648ec03931961;p=coloredstderr%2Fcoloredstderr.git diff --git a/src/trackfds.h b/src/trackfds.h index b9089bb..30693f4 100644 --- a/src/trackfds.h +++ b/src/trackfds.h @@ -72,8 +72,9 @@ static int init_tracked_fds_list(size_t count) { /* Load tracked file descriptors from the environment. The environment is used * to pass the information to child processes. * - * ENV_NAME_FDS has the following format: Each descriptor as string followed - * by a comma; there's a trailing comma. Example: "2,4,". */ + * ENV_NAME_FDS and ENV_NAME_PRIVATE_FDS have the following format: Each + * descriptor as string followed by a comma; there's a trailing comma. + * Example: "2,4,". */ static void init_from_environment(void) { #ifdef DEBUG debug("init_from_environment()\t\t[%d]\n", getpid()); @@ -92,15 +93,24 @@ static void init_from_environment(void) { force_write_to_non_tty = 1; } + /* Prefer user defined list of file descriptors, fall back to file + * descriptors passed through the environment from the parent process. */ env = getenv(ENV_NAME_FDS); + if (env) { + used_fds_set_by_user = 1; + } else { + env = getenv(ENV_NAME_PRIVATE_FDS); + } if (!env) { errno = saved_errno; return; } #ifdef DEBUG debug(" getenv(\"%s\"): \"%s\"\n", ENV_NAME_FDS, env); + debug(" getenv(\"%s\"): \"%s\"\n", ENV_NAME_PRIVATE_FDS, env); #endif - /* Environment is read-only. */ + + /* Environment must be treated read-only. */ char env_copy[strlen(env) + 1]; strcpy(env_copy, env); @@ -124,18 +134,21 @@ static void init_from_environment(void) { } /* ',' at the beginning or double ',' - ignore. */ if (x == last) { - last = x + 1; - continue; + goto next; } if (i == count) { break; } + /* Replace ',' to null-terminate number for atoi(). */ *x = 0; int fd = atoi(last); - if (fd < TRACKFDS_STATIC_COUNT) { + if (fd < 0) { + goto next; + + } else if (fd < TRACKFDS_STATIC_COUNT) { tracked_fds[fd] = 1; } else { if (!tracked_fds_list) { @@ -169,7 +182,7 @@ static char *update_environment_buffer_entry(char *x, int fd) { assert(fd >= 0); int length = snprintf(x, 10 + 1, "%d", fd); - if (length >= 10 + 1) { + if (length >= 10 + 1 || length <= 0 /* shouldn't happen */) { /* Integer too big to fit the buffer, skip it. */ #ifdef WARNING warning("update_environment_buffer_entry(): truncated fd: %d [%d]\n", @@ -181,7 +194,7 @@ static char *update_environment_buffer_entry(char *x, int fd) { /* Write comma after number. */ x += length; *x++ = ','; - /* Make sure the string is always zero terminated. */ + /* Make sure the string is always null-terminated. */ *x = 0; return x; @@ -222,16 +235,29 @@ static void update_environment(void) { return; } + int saved_errno = errno; + char env[update_environment_buffer_size()]; env[0] = 0; update_environment_buffer(env); #if 0 - debug(" setenv(\"%s\", \"%s\", 1)\n", ENV_NAME_FDS, env); + debug(" setenv(\"%s\", \"%s\", 1)\n", ENV_NAME_PRIVATE_FDS, env); #endif + setenv(ENV_NAME_PRIVATE_FDS, env, 1 /* overwrite */); + + /* Child processes must use ENV_NAME_PRIVATE_FDS to get the updated list + * of tracked file descriptors, not the static list provided by the user + * in ENV_NAME_FDS. + * + * But only remove it if the static list in ENV_NAME_FDS was loaded by + * init_from_environment() and merged into ENV_NAME_PRIVATE_FDS. */ + if (used_fds_set_by_user) { + unsetenv(ENV_NAME_FDS); + } - setenv(ENV_NAME_FDS, env, 1 /* overwrite */); + errno = saved_errno; } @@ -325,8 +351,12 @@ static int tracked_fds_find_slow(int fd) noinline; * they are not called often enough. */ inline static int tracked_fds_find(int fd) always_inline; -static int tracked_fds_find(int fd) { - assert(fd >= 0); +inline static int tracked_fds_find(int fd) { + /* Invalid file descriptor. No assert() as we're called from the hooked + * macro. */ + if (unlikely(fd < 0)) { + return 0; + } if (fd < TRACKFDS_STATIC_COUNT) { return tracked_fds[fd]; @@ -336,6 +366,7 @@ static int tracked_fds_find(int fd) { } static int tracked_fds_find_slow(int fd) { assert(initialized); + assert(fd >= 0); if (tracked_fds_list_count == 0) { return 0;