- Use BN_set_negative() and BN_is_negative() instead of subtracting or

comparing to zero.
- Fix fractional number exponentiation, especially for negative exponents.

Obtained from:	OpenBSD
This commit is contained in:
Kevin Lo 2012-12-30 15:20:27 +00:00
parent 2befb3613f
commit c3761c3848
3 changed files with 66 additions and 42 deletions

View File

@ -29,8 +29,6 @@ __FBSDID("$FreeBSD$");
#include "extern.h"
BIGNUM zero;
#define __inline
#define MAX_ARRAY_INDEX 2048
@ -250,8 +248,12 @@ init_bmachine(bool extended_registers)
if (bmachine.readstack == NULL)
err(1, NULL);
bmachine.obase = bmachine.ibase = 10;
BN_init(&zero);
bn_check(BN_zero(&zero));
}
u_int
bmachine_scale(void)
{
return (bmachine.scale);
}
/* Reset the things needed before processing a (new) file */
@ -407,7 +409,7 @@ split_number(const struct number *n, BIGNUM *i, BIGNUM *f)
}
}
__inline void
void
normalize(struct number *n, u_int s)
{
@ -427,7 +429,7 @@ void
negate(struct number *n)
{
bn_check(BN_sub(n->number, &zero, n->number));
BN_set_negative(n->number, !BN_is_negative(n->number));
}
static __inline void
@ -581,7 +583,7 @@ set_scale(void)
n = pop_number();
if (n != NULL) {
if (BN_cmp(n->number, &zero) < 0)
if (BN_is_negative(n->number))
warnx("scale must be a nonnegative number");
else {
scale = get_ulong(n);
@ -878,7 +880,7 @@ load_array(void)
if (inumber == NULL)
return;
idx = get_ulong(inumber);
if (BN_cmp(inumber->number, &zero) < 0)
if (BN_is_negative(inumber->number))
warnx("negative idx");
else if (idx == BN_MASK2 || idx > MAX_ARRAY_INDEX)
warnx("idx too big");
@ -917,7 +919,7 @@ store_array(void)
return;
}
idx = get_ulong(inumber);
if (BN_cmp(inumber->number, &zero) < 0) {
if (BN_is_negative(inumber->number)) {
warnx("negative idx");
stack_free_value(value);
} else if (idx == BN_MASK2 || idx > MAX_ARRAY_INDEX) {
@ -1009,7 +1011,7 @@ bsub(void)
}
void
bmul_number(struct number *r, struct number *a, struct number *b)
bmul_number(struct number *r, struct number *a, struct number *b, u_int scale)
{
BN_CTX *ctx;
@ -1023,11 +1025,9 @@ bmul_number(struct number *r, struct number *a, struct number *b)
bn_check(BN_mul(r->number, a->number, b->number, ctx));
BN_CTX_free(ctx);
if (rscale > bmachine.scale && rscale > ascale && rscale > bscale) {
r->scale = rscale;
normalize(r, max(bmachine.scale, max(ascale, bscale)));
} else
r->scale = rscale;
r->scale = rscale;
if (rscale > bmachine.scale && rscale > ascale && rscale > bscale)
normalize(r, max(scale, max(ascale, bscale)));
}
static void
@ -1046,7 +1046,7 @@ bmul(void)
}
r = new_number();
bmul_number(r, a, b);
bmul_number(r, a, b, bmachine.scale);
push_number(r);
free_number(a);
@ -1172,7 +1172,7 @@ static void
bexp(void)
{
struct number *a, *p, *r;
u_int scale;
u_int rscale;
bool neg;
p = pop_number();
@ -1185,15 +1185,27 @@ bexp(void)
return;
}
if (p->scale != 0)
warnx("Runtime warning: non-zero scale in exponent");
if (p->scale != 0) {
BIGNUM *i, *f;
i = BN_new();
bn_checkp(i);
f = BN_new();
bn_checkp(f);
split_number(p, i, f);
if (!BN_is_zero(f))
warnx("Runtime warning: non-zero fractional part "
"in exponent");
BN_free(i);
BN_free(f);
}
normalize(p, 0);
neg = false;
if (BN_cmp(p->number, &zero) < 0) {
if (BN_is_negative(p->number)) {
neg = true;
negate(p);
scale = bmachine.scale;
rscale = bmachine.scale;
} else {
/* Posix bc says min(a.scale * b, max(a.scale, scale) */
u_long b;
@ -1201,30 +1213,37 @@ bexp(void)
b = BN_get_word(p->number);
m = max(a->scale, bmachine.scale);
scale = a->scale * (u_int)b;
if (scale > m || (a->scale > 0 && (b == BN_MASK2 ||
rscale = a->scale * (u_int)b;
if (rscale > m || (a->scale > 0 && (b == BN_MASK2 ||
b > UINT_MAX)))
scale = m;
rscale = m;
}
if (BN_is_zero(p->number)) {
r = new_number();
bn_check(BN_one(r->number));
normalize(r, scale);
normalize(r, rscale);
} else {
u_int ascale, mscale;
ascale = a->scale;
while (!BN_is_bit_set(p->number, 0)) {
bmul_number(a, a, a);
ascale *= 2;
bmul_number(a, a, a, ascale);
bn_check(BN_rshift1(p->number, p->number));
}
r = dup_number(a);
normalize(r, scale);
bn_check(BN_rshift1(p->number, p->number));
mscale = ascale;
while (!BN_is_zero(p->number)) {
bmul_number(a, a, a);
if (BN_is_bit_set(p->number, 0))
bmul_number(r, r, a);
ascale *= 2;
bmul_number(a, a, a, ascale);
if (BN_is_bit_set(p->number, 0)) {
mscale += ascale;
bmul_number(r, r, a, mscale);
}
bn_check(BN_rshift1(p->number, p->number));
}
@ -1237,13 +1256,18 @@ bexp(void)
bn_check(BN_one(one));
ctx = BN_CTX_new();
bn_checkp(ctx);
scale_number(one, r->scale + scale);
normalize(r, scale);
bn_check(BN_div(r->number, NULL, one, r->number, ctx));
scale_number(one, r->scale + rscale);
if (BN_is_zero(r->number))
warnx("divide by zero");
else
bn_check(BN_div(r->number, NULL, one,
r->number, ctx));
BN_free(one);
BN_CTX_free(ctx);
r->scale = rscale;
} else
normalize(r, scale);
normalize(r, rscale);
}
push_number(r);
free_number(a);
@ -1282,7 +1306,7 @@ bsqrt(void)
if (BN_is_zero(n->number)) {
r = new_number();
push_number(r);
} else if (BN_cmp(n->number, &zero) < 0)
} else if (BN_is_negative(n->number))
warnx("square root of negative number");
else {
scale = max(bmachine.scale, n->scale);

View File

@ -85,6 +85,7 @@ struct source {
void init_bmachine(bool);
void reset_bmachine(struct source *);
u_int bmachine_scale(void);
void scale_number(BIGNUM *, int);
void normalize(struct number *, u_int);
void eval(void);
@ -93,6 +94,4 @@ void pbn(const char *, const BIGNUM *);
void negate(struct number *);
void split_number(const struct number *, BIGNUM *, BIGNUM *);
void bmul_number(struct number *, struct number *,
struct number *);
extern BIGNUM zero;
struct number *, u_int);

View File

@ -322,7 +322,7 @@ printnumber(FILE *f, const struct number *b, u_int base)
i++;
}
sz = i;
if (BN_cmp(b->number, &zero) < 0)
if (BN_is_negative(b->number))
putcharwrap(f, '-');
for (i = 0; i < sz; i++) {
p = stack_popstring(&stack);
@ -353,7 +353,8 @@ printnumber(FILE *f, const struct number *b, u_int base)
putcharwrap(f, ' ');
i = 1;
bmul_number(fract_part, fract_part, num_base);
bmul_number(fract_part, fract_part, num_base,
bmachine_scale());
split_number(fract_part, int_part->number, NULL);
rem = BN_get_word(int_part->number);
p = get_digit(rem, digits, base);
@ -402,8 +403,8 @@ print_ascii(FILE *f, const struct number *n)
v = BN_dup(n->number);
bn_checkp(v);
if (BN_cmp(v, &zero) < 0)
bn_check(BN_sub(v, &zero, v));
if (BN_is_negative(v))
BN_set_negative(v, 0);
numbits = BN_num_bytes(v) * 8;
while (numbits > 0) {