/*
 * Copyright 1995,96 Thierry Bousch
 * Licensed under the Gnu Public License, Version 2
 *
 * $Id: Poly.c,v 2.5 1996/08/18 09:26:40 bousch Exp $
 *
 * Operations on Polynomials
 */

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "saml.h"
#include "saml-errno.h"
#include "mnode.h"
#include "builtin.h"

/* Methods installed in the MathType_Polynomial table below */
static gr_string* poly_stringify (std_mnode*);
static s_mnode* poly_make (s_mnode*);
static s_mnode* poly_add (std_mnode*, std_mnode*);
static s_mnode* poly_mul (std_mnode*, std_mnode*);
static s_mnode* poly_div (s_mnode*, s_mnode*);
static s_mnode* poly_gcd (std_mnode*, std_mnode*);
static int poly_notzero (std_mnode*);
static s_mnode* poly_zero (std_mnode*);
static s_mnode* poly_negate (std_mnode*);
static s_mnode* poly_one (std_mnode*);
/* Conversion routines into ST_POLY, registered by init_MathType_Polynomial() */
static s_mnode* literal2poly (s_mnode*, std_mnode*);
static s_mnode* mono2poly (s_mnode*, std_mnode*);
static s_mnode* poly2poly (std_mnode*, std_mnode*);

/* Helpers used below but defined outside this file */
extern s_mnode* decompose_powers_umono (std_mnode*, s_mnode*);
extern std_mnode* mono_unpack (s_mnode*);
extern s_mnode* upoly_eval (s_mnode*, s_mnode*);

/*
 * Method table of the "Polynomial" type, registered under ST_POLY by
 * init_MathType_Polynomial().  The slot order is dictated by the
 * unsafe_s_mtype definition, which is not visible in this file; NULL
 * marks slots this file does not implement.  Generic mstd_free /
 * mn_std_sub / mn_std_differ are used where no polynomial-specific
 * version exists.
 */
static unsafe_s_mtype MathType_Polynomial = {
	"Polynomial",
	mstd_free, NULL, poly_stringify,
	poly_make, NULL,
	poly_add, mn_std_sub, poly_mul, poly_div, poly_gcd,
	poly_notzero, NULL, NULL, mn_std_differ, NULL,
	poly_zero, poly_negate, poly_one, NULL, NULL
};

/*
 * Register the Polynomial type under ST_POLY, then the routines that
 * convert literals (ST_LITERAL), monomials (ST_MONO) and polynomials
 * (ST_POLY) into ST_POLY values.
 *
 * NOTE(review): whether register_CV_routine() requires the target type
 * to be registered first is not visible here; the existing order is
 * kept deliberately.
 */
void init_MathType_Polynomial (void)
{
	register_mtype(ST_POLY, &MathType_Polynomial);
	register_CV_routine(ST_LITERAL, ST_POLY, literal2poly);
	register_CV_routine(ST_MONO, ST_POLY, mono2poly);
	register_CV_routine(ST_POLY, ST_POLY, poly2poly);
}

/*
 * Return the zero polynomial with the same zero monomial as MN1.
 * A zero polynomial is a node of length 1 whose only slot, x[0],
 * holds the zero monomial.
 */
static s_mnode* poly_zero (std_mnode* mn1)
{
	std_mnode *zero = mstd_alloc(ST_POLY, 1);

	zero->x[0] = copy_mnode(mn1->x[0]);
	return (mn_ptr)zero;
}

/*
 * Return the constant polynomial 1 built from MN1's zero monomial:
 * x[0] is that zero monomial, and x[1] is mnode_one() of it.
 */
static s_mnode* poly_one (std_mnode* mn1)
{
	s_mnode *zmono = mn1->x[0];
	std_mnode *one = mstd_alloc(ST_POLY, 2);

	one->x[0] = copy_mnode(zmono);
	one->x[1] = mnode_one(zmono);
	return (mn_ptr)one;
}

/*
 * Build the constant polynomial made from CONSTANT.
 *
 * The constant is first turned into a monomial with mnode_make(ST_MONO).
 * If that yields an ST_VOID node, the node is returned unchanged so the
 * caller sees the failure.  Otherwise:
 *   - zero constant:     length 1, x[0] = the (zero) monomial;
 *   - nonzero constant:  length 2, x[0] = its zero monomial,
 *                        x[1] = the constant monomial.
 * Returns a new reference.
 */
static s_mnode* poly_make (s_mnode *constant)
{
	std_mnode *mn;
	s_mnode *cmono;
	int flag;

	flag = mnode_notzero(constant);
	cmono = mnode_make(ST_MONO, constant);
	if (cmono->type == ST_VOID) {
		/* Nothing else has been allocated yet, so just propagate. */
		return cmono;
	}
	/*
	 * Allocate only now, when every slot is about to be filled.  The
	 * polynomial node must never be unlinked while its x[] slots are
	 * still unassigned.
	 */
	mn = mstd_alloc(ST_POLY, flag? 2 : 1);
	if (flag) {
		mn->x[1] = cmono;
		mn->x[0] = mnode_zero(cmono);
	} else
		mn->x[0] = cmono;

	return (mn_ptr)mn;
}
	
/*
 * Conversion routine ST_LITERAL -> ST_POLY.
 *
 * The literal is promoted using the model's x[0] (its zero monomial)
 * as the target, and the resulting monomial is wrapped by mono2poly().
 * A model is mandatory: without one, an SE_ICAST error node is returned.
 */
static s_mnode* literal2poly (s_mnode* lit, std_mnode* model)
{
	s_mnode *as_mono, *as_poly;

	if (model) {
		as_mono = mnode_promote(lit, model->x[0]);
		as_poly = mono2poly(as_mono, NULL);
		unlink_mnode(as_mono);
		return as_poly;
	}
	return mnode_error(SE_ICAST, "literal2poly");
}

/*
 * Conversion routine ST_MONO -> ST_POLY.
 *
 * With a model, MONO is first promoted into the model's monomial type
 * (via the model's x[0]); without one, MONO is used as is.  A zero
 * monomial gives the zero polynomial (length 1).  Any other monomial
 * gives a one-term polynomial: x[0] is its zero, x[1] is the monomial.
 */
static s_mnode* mono2poly (s_mnode* mono, std_mnode* model)
{
	std_mnode *poly;
	s_mnode *term;
	int nonzero;

	term = model ? mnode_promote(mono, model->x[0]) : copy_mnode(mono);
	nonzero = mnode_notzero(term);
	poly = mstd_alloc(ST_POLY, nonzero ? 2 : 1);
	if (!nonzero) {
		poly->x[0] = copy_mnode(term);
	} else {
		poly->x[0] = mnode_zero(term);
		poly->x[1] = copy_mnode(term);
	}
	unlink_mnode(term);
	return (mn_ptr)poly;
}

/*
 * Conversion routine ST_POLY -> ST_POLY.
 *
 * Without a model, P is simply copied.  Otherwise a polynomial of the
 * same length is built: x[0] becomes the model's zero monomial, and
 * each term of P is promoted into that monomial type.
 *
 * NOTE(review): unlike mono2poly(), promoted terms are not re-checked
 * with mnode_notzero(); confirm that promotion can never produce a
 * zero term.
 */
static s_mnode* poly2poly (std_mnode* P, std_mnode* model)
{
	std_mnode *Q;
	s_mnode *target;
	int n, k;

	if (model == NULL)
		return copy_mnode((mn_ptr)P);
	n = P->length;
	target = model->x[0];
	Q = mstd_alloc(ST_POLY, n);
	Q->x[0] = copy_mnode(target);
	k = 1;
	while (k < n) {
		Q->x[k] = mnode_promote(P->x[k], target);
		k++;
	}
	return (mn_ptr)Q;
}

/*
 * Return -MN1.  The zero monomial (x[0]) is shared, and every term
 * from x[1] onwards is negated with mnode_negate(), keeping the order.
 */
static s_mnode* poly_negate (std_mnode* mn1)
{
	std_mnode *neg;
	s_mnode **src, **dst, **src_end;

	neg = mstd_alloc(ST_POLY, mn1->length);
	src = mn1->x;
	src_end = mn1->x + mn1->length;
	dst = neg->x;
	*dst = copy_mnode(*src);
	while (++src < src_end)
		*++dst = mnode_negate(*src);
	return (mn_ptr)neg;
}

/*
 * Return 0 for the zero polynomial (a node of length 1 holding only its
 * zero monomial) and 1 otherwise.  Terms are never stored when zero.
 */
static int poly_notzero (std_mnode* mn1)
{
	if (mn1->length == 1)
		return 0;
	return 1;
}

/*
 * Render MN1 as text.  The zero polynomial prints as its zero monomial.
 * Otherwise the terms are concatenated in storage order, and a '+' is
 * inserted before any term whose text does not already start with
 * '+' or '-'.  Returns a newly allocated gr_string.
 */
static gr_string* poly_stringify (std_mnode* mn1)
{
	gr_string *out, *piece;
	int k;

	if (mn1->length == 1)
		return mnode_stringify(mn1->x[0]);
	out = new_gr_string(0);
	for (k = 1; k < mn1->length; k++) {
		piece = mnode_stringify(mn1->x[k]);
		switch (piece->s[0]) {
		case '+':
		case '-':
			/* already signed */
			break;
		default:
			out = grs_append1(out, '+');
			break;
		}
		out = grs_append(out, piece->s, piece->len);
		free(piece);
	}
	return out;
}

/*
 * Sum of two polynomials.
 *
 * Both term lists (x[1] onwards) are kept in increasing mono_compare()
 * order, so the sum is a single merge pass.  Terms that compare equal
 * are combined with mono_add_sim().  If a combined term fails
 * mnode_notzero(), it is dropped, so the result never stores a zero
 * term.  If either operand is zero (length 1), a copy of the other is
 * returned.  The result takes its zero monomial (x[0]) from MN1.
 * Returns a new reference.
 *
 * NOTE(review): the scratch buffer is alloca'd, so its stack usage grows
 * with len1+len2.  alloca is not standard C; confirm it is declared
 * (e.g. via <alloca.h> or saml.h).
 */
static s_mnode* poly_add (std_mnode* mn1, std_mnode* mn2)
{
	int diff, len, len1, len2;
	s_mnode **sum, **p1, **p2, **p, **p1_end, **p2_end, *term;
	std_mnode *mn;
	extern int mono_compare (s_mnode*, s_mnode*);
	extern s_mnode* mono_add_sim (s_mnode*, s_mnode*);

	if ((len1 = mn1->length) == 1)
		return copy_mnode((mn_ptr)mn2);
	if ((len2 = mn2->length) == 1)
		return copy_mnode((mn_ptr)mn1);
	/* At most (len1-1) + (len2-1) terms can survive the merge */
	p = sum = alloca((len1+len2-2) * sizeof(s_mnode*));
	p1 = &mn1->x[1]; p1_end = &mn1->x[len1];
	p2 = &mn2->x[1]; p2_end = &mn2->x[len2];
	while (p1 < p1_end && p2 < p2_end) {
		diff = mono_compare(*p1, *p2);
		if (diff < 0)
			*p++ = copy_mnode(*p1++);
		else if (diff > 0)
			*p++ = copy_mnode(*p2++);
		else {
			/*
			 * Like terms: only advance the output pointer when the
			 * combined term is kept.  Never step p back below sum.
			 */
			term = mono_add_sim(*p1++, *p2++);
			if (mnode_notzero(term))
				*p++ = term;
			else
				unlink_mnode(term);
		}
	}
	/* Copy whichever tail remains (at most one of them is non-empty) */
	while (p1 < p1_end)
		*p++ = copy_mnode(*p1++);
	while (p2 < p2_end)
		*p++ = copy_mnode(*p2++);
	len = p - sum;
	assert(len <= len1+len2-2);
	mn = mstd_alloc(ST_POLY, len+1);
	mn->x[0] = copy_mnode(mn1->x[0]);
	memcpy(&mn->x[1], sum, len * sizeof(s_mnode*));
	return (mn_ptr)mn;
}

/*
 * Multiply MN1 by the sum of the LENGTH monomials stored at LIST.
 * Callers pass LENGTH >= 1.
 *
 * With more than one monomial, LIST is cut into two halves.  The two
 * partial products are computed recursively and combined with
 * poly_add().  With a single monomial, every term of MN1 (x[1] onwards)
 * is multiplied by it, and products that fail mnode_notzero() are
 * dropped.  The result reuses MN1's zero monomial as x[0].  Returns a
 * new reference.
 *
 * NOTE(review): the base case uses alloca, sized by MN1's length.
 */
static std_mnode* poly_split_mul (std_mnode* mn1, s_mnode** list, int length)
{
	s_mnode **buf, **dst, **src, **src_end, *factor, *prod;
	std_mnode *result;
	int count;

	if (length > 1) {
		int half = length / 2;
		std_mnode *left = poly_split_mul(mn1, list, half);
		std_mnode *right = poly_split_mul(mn1, list + half, length - half);

		result = (smn_ptr)poly_add(left, right);
		unlink_mnode((mn_ptr)left);
		unlink_mnode((mn_ptr)right);
		return result;
	}

	/* Single monomial: scale every term of mn1 by it */
	dst = buf = alloca((mn1->length) * sizeof(s_mnode*));
	*dst++ = copy_mnode(mn1->x[0]);
	factor = list[0];
	src_end = &mn1->x[mn1->length];
	for (src = &mn1->x[1]; src < src_end; src++) {
		prod = mnode_mul(*src, factor);
		if (!mnode_notzero(prod)) {
			/* zero product: not stored */
			unlink_mnode(prod);
			continue;
		}
		*dst++ = prod;
	}
	count = dst - buf;
	assert(count <= mn1->length);
	result = mstd_alloc(ST_POLY, count);
	memcpy(result->x, buf, count * sizeof(s_mnode*));
	return result;
}

/*
 * Product of two polynomials.  If either factor is zero (length 1),
 * that factor is returned as the product.  Otherwise the longer factor
 * is multiplied by the terms of the shorter one using poly_split_mul().
 * On a tie, MN2 plays the role of the longer factor.
 */
static s_mnode* poly_mul (std_mnode* mn1, std_mnode* mn2)
{
	std_mnode *longer, *shorter;

	if (mn1->length == 1)
		return copy_mnode((mn_ptr)mn1);
	if (mn2->length == 1)
		return copy_mnode((mn_ptr)mn2);
	if (mn1->length > mn2->length) {
		longer = mn1;
		shorter = mn2;
	} else {
		longer = mn2;
		shorter = mn1;
	}
	return (mn_ptr)poly_split_mul(longer, &shorter->x[1], shorter->length - 1);
}

/*
 * Quotient of MN1 by MN2.  MN2 must be nonzero; otherwise an
 * SE_DIVZERO error node is returned.
 *
 * Terms are stored in increasing order, and that order is compatible
 * with multiplication, so the leading (last) terms drive the division.
 * At each step the leading term of the remainder is divided by the
 * leading term of MN2.  The result is promoted to a polynomial and
 * added to the quotient, and its product with MN2 is subtracted from
 * the remainder.  The loop stops when the remainder is zero, or when
 * mnode_div() yields a zero monomial.  Only the quotient is returned;
 * the final remainder is discarded.
 *
 * NOTE(review): termination relies on each step cancelling the
 * remainder's leading term; confirm this holds for the coefficient
 * rings in use.
 */
static s_mnode* poly_div (s_mnode* mn1, s_mnode* mn2)
{
	int dlen, rlen;
	s_mnode *lead_div, *q, *qpoly, *quot, *rem, *prod, *tmp;

	dlen = ((smn_ptr)mn2)->length;
	if (dlen == 1)
		return mnode_error(SE_DIVZERO, "poly_div");
	lead_div = ((smn_ptr)mn2)->x[dlen-1];
	quot = mnode_zero(mn1);
	rem = copy_mnode(mn1);
	for (;;) {
		rlen = ((smn_ptr)rem)->length;
		if (rlen == 1)
			break;		/* remainder is zero */
		q = mnode_div(((smn_ptr)rem)->x[rlen-1], lead_div);
		if (!mnode_notzero(q)) {
			unlink_mnode(q);
			break;		/* monomial quotient is zero */
		}
		qpoly = mnode_promote(q, rem);
		unlink_mnode(q);
		tmp = mnode_add(quot, qpoly);
		unlink_mnode(quot);
		quot = tmp;
		prod = mnode_mul(qpoly, mn2);
		unlink_mnode(qpoly);
		tmp = mnode_sub(rem, prod);
		unlink_mnode(rem);
		unlink_mnode(prod);
		rem = tmp;
	}
	unlink_mnode(rem);
	return quot;
}

/*
 * Substitute EXP for the monomial UMONO in POLY.
 *
 * A one-term polynomial is accepted in place of UMONO.  POLY is written
 * as a univariate polynomial in UMONO (decompose_powers_umono) and then
 * evaluated at EXP (upoly_eval).  If UMONO is not a monomial or EXP is
 * not a polynomial, an SE_TCONFL error node is returned.
 */
s_mnode* poly_subs (std_mnode* poly, s_mnode* umono, s_mnode* exp)
{
	s_mnode *coeffs, *value;

	if (umono->type == ST_POLY && ((smn_ptr)umono)->length == 2)
		umono = ((smn_ptr)umono)->x[1];
	if (umono->type == ST_MONO && exp->type == ST_POLY) {
		coeffs = decompose_powers_umono(poly, umono);
		value = upoly_eval(coeffs, exp);
		unlink_mnode(coeffs);
		return value;
	}
	return mnode_error(SE_TCONFL, "poly_subs");
}

/*
 * GCD of two polynomials.
 *
 * gcd(0, Q) = Q and gcd(P, 0) = P.  Otherwise a literal is picked from
 * the leading (last) term of MN1, or from MN2's leading term if MN1 is
 * constant.  A term counts as constant when mono_unpack() returns fewer
 * than 3 slots; its slot 1 is used as the literal.  If both polynomials
 * are constant, the gcd of the unpacked coefficients (slot 0) is
 * promoted back to a polynomial.  Otherwise both polynomials are
 * rewritten as univariate polynomials in that literal, their gcd is
 * computed, and the result is evaluated back at the literal.
 */
static s_mnode* poly_gcd (std_mnode* mn1, std_mnode* mn2)
{
	std_mnode *unpacked, *unpacked2;
	s_mnode *var, *var_poly, *u1, *u2, *ugcd, *answer;

	if (mn1->length == 1)
		return copy_mnode((mn_ptr)mn2);
	if (mn2->length == 1)
		return copy_mnode((mn_ptr)mn1);

	unpacked = mono_unpack(mn1->x[mn1->length-1]);
	if (unpacked->length < 3) {
		/* mn1 is constant: take the literal from mn2 instead */
		unpacked2 = mono_unpack(mn2->x[mn2->length-1]);
		if (unpacked2->length < 3) {
			/* both constant: gcd of the coefficients */
			ugcd = mnode_gcd(unpacked->x[0], unpacked2->x[0]);
			unlink_mnode((mn_ptr)unpacked);
			unlink_mnode((mn_ptr)unpacked2);
			answer = mnode_promote(ugcd, (mn_ptr)mn1);
			unlink_mnode(ugcd);
			return answer;
		}
		unlink_mnode((mn_ptr)unpacked);
		unpacked = unpacked2;
	}
	var = mnode_promote(unpacked->x[1], mn1->x[0]);
	unlink_mnode((mn_ptr)unpacked);

	/* univariate gcd in var */
	u1 = decompose_powers_umono(mn1, var);
	u2 = decompose_powers_umono(mn2, var);
	ugcd = mnode_gcd(u1, u2);
	unlink_mnode(u1);
	unlink_mnode(u2);

	/* back to a multivariate polynomial */
	var_poly = mnode_promote(var, (mn_ptr)mn1);
	unlink_mnode(var);
	answer = upoly_eval(ugcd, var_poly);
	unlink_mnode(ugcd);
	unlink_mnode(var_poly);
	return answer;
}

/*
 * Rewrite MN1 and MN2 as univariate polynomials in the monomial LIT and
 * return upoly_sylvester() of the pair.  The result is presumably the
 * Sylvester matrix; confirm in upoly_sylvester.
 *
 * A one-term polynomial is accepted in place of LIT.  If LIT is not a
 * monomial, an SE_TCONFL error node is returned.
 */
s_mnode* poly_sylvester (std_mnode* mn1, std_mnode* mn2, s_mnode* lit)
{
	extern s_mnode* upoly_sylvester(s_mnode*,s_mnode*);
	s_mnode *u1, *u2, *sylv;

	if (lit->type == ST_POLY && ((smn_ptr)lit)->length == 2)
		lit = ((smn_ptr)lit)->x[1];
	if (lit->type == ST_MONO) {
		u1 = decompose_powers_umono(mn1, lit);
		u2 = decompose_powers_umono(mn2, lit);
		sylv = upoly_sylvester(u1, u2);
		unlink_mnode(u1);
		unlink_mnode(u2);
		return sylv;
	}
	return mnode_error(SE_TCONFL, "poly_sylvester");
}

/*
 * Derivative of MN1 with respect to the monomial LIT.
 *
 * MN1 is rewritten as a univariate polynomial in LIT and differentiated
 * with upoly_diff().  The result is evaluated back at LIT (promoted to a
 * polynomial).  A one-term polynomial is accepted in place of LIT.  If
 * LIT is not a monomial, an SE_TCONFL error node is returned.
 */
s_mnode* poly_diff (std_mnode* mn1, s_mnode* lit)
{
	extern s_mnode* upoly_diff(s_mnode*);
	s_mnode *u, *du, *var_poly, *deriv;

	if (lit->type == ST_POLY && ((smn_ptr)lit)->length == 2)
		lit = ((smn_ptr)lit)->x[1];
	if (lit->type == ST_MONO) {
		u = decompose_powers_umono(mn1, lit);
		du = upoly_diff(u);
		unlink_mnode(u);
		var_poly = mnode_promote(lit, (mn_ptr)mn1);
		deriv = upoly_eval(du, var_poly);
		unlink_mnode(du);
		unlink_mnode(var_poly);
		return deriv;
	}
	return mnode_error(SE_TCONFL, "poly_diff");
}
