Rewrite str2wcstring to properly handle embedded nulls, and be simpler

This commit is contained in:
ridiculousfish 2012-12-20 12:25:35 -08:00
parent d5af389d2e
commit ce15abd577
3 changed files with 71 additions and 89 deletions

View file

@ -81,8 +81,6 @@ parts of fish.
#include "fallback.cpp"
static wchar_t *str2wcs_internal(const char *in, const size_t in_len, wchar_t *out);
struct termios shell_modes;
// Note we foolishly assume that pthread_t is just a primitive. But it might be a struct.
@ -164,45 +162,6 @@ int fgetws2(wcstring *s, FILE *f)
}
}
static wchar_t *str2wcs(const char *in)
{
size_t len = strlen(in);
wchar_t *out = (wchar_t *)malloc(sizeof(wchar_t)*(len+1));
if (!out)
{
DIE_MEM();
}
return str2wcs_internal(in, strlen(in), out);
}
wcstring str2wcstring(const char *in, size_t len)
{
assert(in != NULL);
std::string tmp_str(in, len);
wchar_t *tmp = str2wcs(tmp_str.c_str());
wcstring result = tmp;
free(tmp);
return result;
}
wcstring str2wcstring(const char *in)
{
assert(in != NULL);
wchar_t *tmp = str2wcs(in);
wcstring result = tmp;
free(tmp);
return result;
}
wcstring str2wcstring(const std::string &in)
{
wchar_t *tmp = str2wcs(in.c_str());
wcstring result = tmp;
free(tmp);
return result;
}
/**
Converts the narrow character string \c in into it's wide
equivalent, stored in \c out. \c out must have enough space to fit
@ -213,63 +172,87 @@ wcstring str2wcstring(const std::string &in)
This function encodes illegal character sequences in a reversible
way using the private use area.
*/
static wchar_t *str2wcs_internal(const char *in, const size_t in_len, wchar_t *out)
static wcstring str2wcs_internal(const char *in, const size_t in_len)
{
size_t res=0;
if (in_len == 0)
return wcstring();
assert(in != NULL);
wcstring result;
result.reserve(in_len);
mbstate_t state = {};
size_t in_pos = 0;
size_t out_pos = 0;
mbstate_t state;
CHECK(in, 0);
CHECK(out, 0);
memset(&state, 0, sizeof(state));
while (in[in_pos])
while (in_pos < in_len)
{
res = mbrtowc(&out[out_pos], &in[in_pos], in_len-in_pos, &state);
wchar_t wc = 0;
size_t ret = mbrtowc(&wc, &in[in_pos], in_len-in_pos, &state);
if (((out[out_pos] >= ENCODE_DIRECT_BASE) &&
(out[out_pos] < ENCODE_DIRECT_BASE+256)) ||
(out[out_pos] == INTERNAL_SEPARATOR))
/* Determine whether to encode this characters with our crazy scheme */
bool use_encode_direct = false;
if (wc >= ENCODE_DIRECT_BASE && wc < ENCODE_DIRECT_BASE+256)
{
out[out_pos] = ENCODE_DIRECT_BASE + (unsigned char)in[in_pos];
use_encode_direct = true;
}
else if (wc == INTERNAL_SEPARATOR)
{
use_encode_direct = true;
}
else if (ret == (size_t)(-2))
{
/* Incomplete sequence */
use_encode_direct = true;
}
else if (ret == (size_t)(-1))
{
/* Invalid data */
use_encode_direct = true;
}
else if (ret > in_len - in_pos)
{
/* Other error codes? Terrifying, should never happen */
use_encode_direct = true;
}
if (use_encode_direct)
{
wc = ENCODE_DIRECT_BASE + (unsigned char)in[in_pos];
result.push_back(wc);
in_pos++;
memset(&state, 0, sizeof(state));
out_pos++;
bzero(&state, sizeof state);
}
else if (ret == 0)
{
/* Embedded null byte! */
result.push_back(L'\0');
in_pos++;
bzero(&state, sizeof state);
}
else
{
/* Normal case */
result.push_back(wc);
in_pos += ret;
}
}
return result;
}
switch (res)
wcstring str2wcstring(const char *in, size_t len)
{
case (size_t)(-2):
case (size_t)(-1):
return str2wcs_internal(in, len);
}
wcstring str2wcstring(const char *in)
{
out[out_pos] = ENCODE_DIRECT_BASE + (unsigned char)in[in_pos];
in_pos++;
memset(&state, 0, sizeof(state));
break;
return str2wcs_internal(in, strlen(in));
}
case 0:
wcstring str2wcstring(const std::string &in)
{
return out;
}
default:
{
in_pos += res;
break;
}
}
out_pos++;
}
}
out[out_pos] = 0;
return out;
/* Handles embedded nulls! */
return str2wcs_internal(in.data(), in.size());
}
char *wcs2str(const wchar_t *in)

View file

@ -284,7 +284,6 @@ static void test_convert()
/* Verify correct behavior with embedded nulls */
static void test_convert_nulls(void)
{
return;
say(L"Testing embedded nulls in string conversion");
const wchar_t in[] = L"AAA\0BBB";
const size_t in_len = (sizeof in / sizeof *in) - 1;