#include "global.h"
#include "net.h"
#include "payload.h"


/*
 * Sanity-check a Payload before it is read or written.
 * A payload is valid when it is non-NULL, its length does not exceed the
 * buffer capacity PAYLOAD_LEN, and its cursor (idx) is not beyond its length.
 * Lower bounds are not tested (the ">= 0" checks were left commented out,
 * presumably because len/idx are unsigned types).
 * Returns 1 if valid, 0 otherwise.
 */
int isValidPayload(const Payload *payload)
{
  if (payload == NULL)
    return 0;
  if (payload->len > PAYLOAD_LEN)
    return 0;
  return payload->idx <= payload->len;
}

/*
 * Clear the whole Payload (data buffer, length and cursor) to zero.
 * `payload` must not be NULL. Returns `payload`.
 */
Payload * resetPayload(Payload *payload)
{
  return memset(payload, 0, sizeof *payload);
}

/*
 * Move the cursor back to the start of the payload; the length is kept so
 * the existing content can be read again. Asserts that the payload is valid.
 * Returns `payload`.
 */
Payload * rewindPayload(Payload *payload)
{
  Payload *const p = payload;

  assert(isValidPayload(p));
  p->idx = 0;
  return p;
}

/*
 * Disabled (compiled out by #if 0). Would copy the a->len bytes of `a` to
 * the cursor of `t`, advance t's cursor by that amount and grow t->len if the
 * cursor moved past it.
 * NOTE(review): invalid inputs only produce a warning (the copy still runs),
 * and nothing checks that t->idx + a->len fits in PAYLOAD_LEN; both must be
 * fixed before re-enabling this.
 */
#if 0
static
void appendPayload(Payload *t, const Payload *a)
{
  if (!isValidPayload(t) || !isValidPayload(a))
    warning("appendPayload: invalid Payload");
  memcpy(t->data + t->idx, a->data, a->len);
  t->idx += a->len;
  if (t->idx > t->len) t->len = t->idx;
} 
#endif

/*
 * Debug helper: write a hex/ASCII dump of `payload` to stream `f`.
 * The header line shows len and idx. Each of the first `len` bytes is then
 * printed as "0xHH" followed by the character itself (or '.' if it is not
 * printable). The byte just before the cursor (index idx-1) is followed by
 * '>' so the current position is visible.
 * A NULL payload prints "NULL". Returns `payload` unchanged.
 */
Payload * dumpPayload(FILE *f, Payload *payload)
{
  int k;

  fprintf(f, "dump_payload: ");
  if (payload == NULL) {
    fprintf(f, "NULL\n");
    return NULL;
  }

  fprintf(f, "len=%d, idx=%d\n", payload->len, payload->idx);
  for (k = 0; k < payload->len; k++) {
    int byte = (u_int8) payload->data[k];

    if (isprint(byte))
      fprintf(f, "0x%02x %c ", byte, byte);
    else
      fprintf(f, "0x%02x . ", byte);
    fputc((k == payload->idx - 1) ? '>' : ' ', f);
  }
  fprintf(f, "\n");
  return payload;
}

/*
 * Capacity helper for putPayload(). It checks, BEFORE anything is written,
 * that `need` more bytes fit at the current cursor. The limit is exclusive:
 * a payload whose length reaches PAYLOAD_LEN is "too long", which is the same
 * rule putPayload() has always applied. It used to be applied only after the
 * bytes had already been written, possibly past the end of data[].
 * Returns 0 if the bytes fit. Otherwise it warns, resets len/idx to 0 and
 * returns -1.
 */
static int putPayloadReserve(Payload *payload, size_t need)
{
  size_t end = (size_t) payload->idx + need;

  if (end >= (size_t) PAYLOAD_LEN) {
    warning("putPayload: Payload too long (%d > %d)", (int) end, PAYLOAD_LEN);
    payload->len = payload->idx = 0; /* just in case */
    return -1;
  }
  return 0;
}

/*
 * Serialize the variadic arguments into `payload` as described by `format`.
 * Writing starts at the current cursor (idx), and len is first truncated to
 * idx ("rewrite" mode, not "append").
 *
 * Each item is written as a one-byte tag (the format character itself),
 * followed by its value:
 *   'c'  int (char)          -> 1 byte
 *   'h'  int (short)         -> 2 bytes, htons
 *   'd'  int                 -> 4 bytes, htonl
 *   'f'  double (float)      -> 4 bytes, float bit pattern through htonl
 *   's'  NUL-terminated str  -> 1 length byte + up to 255 bytes (no NUL)
 *   'n'  NetObjectId         -> src_id, port_id, obj_id copied as-is
 *   't'  struct timeval      -> tv_sec, tv_usec as 32-bit htonl values
 *
 * Returns 0 on success. Returns -1 for:
 *   - an invalid payload or a NULL format;
 *   - an unknown format character;
 *   - a NULL string;
 *   - a string longer than 255 bytes, or an item that would make the length
 *     reach PAYLOAD_LEN (in these two cases len/idx are reset to 0).
 */
int putPayload(Payload *payload, const char *format, ...)
{
  va_list ap;

  if (! isValidPayload(payload)) {
    /* "%p": the previous "%s" with a Payload* argument was undefined */
    warning("putPayload: invalid Payload %p", (void *) payload);
    return -1;
  }
  if (format == NULL) {
    warning("putPayload: NULL format");
    return -1;
  }

  va_start(ap, format);
  payload->len = payload->idx;   /* "rewrite" mode rather than "append" */

  for (; *format; format++) {
    switch (*format) {
    case 'c': /* char (promoted to int through "...") */
      {
        u_int8 c = (u_int8) va_arg(ap, int);

        if (putPayloadReserve(payload, 1 + sizeof(u_int8)) < 0)
          goto fail;
        payload->data[payload->idx++] = 'c';
        payload->data[payload->idx++] = c;
        break;
      }
    case 'h': /* short (promoted to int through "...") */
      {
        int16 h = htons(va_arg(ap, int));

        if (putPayloadReserve(payload, 1 + sizeof(int16)) < 0)
          goto fail;
        payload->data[payload->idx++] = 'h';
        memcpy(payload->data + payload->idx, &h, sizeof(int16));
        payload->idx += sizeof(int16);
        break;
      }
    case 'd': /* int32 */
      {
        int32 d = htonl(va_arg(ap, int));

        if (putPayloadReserve(payload, 1 + sizeof(int32)) < 0)
          goto fail;
        payload->data[payload->idx++] = 'd';
        memcpy(payload->data + payload->idx, &d, sizeof(int32));
        payload->idx += sizeof(int32);
        break;
      }
    case 'f': /* float (promoted to double through "...") */
      {
        float f = (float) va_arg(ap, double);
        int32 bits;

        /* copy the bit pattern: reading a float through an int32* would
           violate strict aliasing */
        memcpy(&bits, &f, sizeof(int32));
        bits = htonl(bits);
        if (putPayloadReserve(payload, 1 + sizeof(int32)) < 0)
          goto fail;
        payload->data[payload->idx++] = 'f';
        memcpy(payload->data + payload->idx, &bits, sizeof(int32));
        payload->idx += sizeof(int32);
        break;
      }
    case 's': /* string, sent as <length byte><bytes> */
      {
        u_int8 *s = va_arg(ap, u_int8*);
        size_t len;

        if (s == NULL) {
          warning("putPayload: NULL string");
          goto fail;
        }
        len = strlen((char *) s);
        /* The length travels in a single byte. Longer strings used to be
           sent with a wrapped length, which corrupted the rest of the
           payload for the reader. */
        if (len > 0xFF) {
          warning("putPayload: string too long (%d > %d)", (int) len, 0xFF);
          payload->len = payload->idx = 0;
          goto fail;
        }
        if (putPayloadReserve(payload, 2 + len) < 0)
          goto fail;
        payload->data[payload->idx++] = 's';
        payload->data[payload->idx++] = (u_int8) len;
        memcpy(payload->data + payload->idx, s, len);
        payload->idx += len;
        break;
      }
    case 'n': /* NetObjectId */
      {
        NetObjectId n = va_arg(ap, NetObjectId);

        if (putPayloadReserve(payload, 1 + sizeof(n.src_id)
                              + sizeof(n.port_id) + sizeof(n.obj_id)) < 0)
          goto fail;
        payload->data[payload->idx++] = 'n';

        /* everything is already in network format */
        memcpy(payload->data + payload->idx, &n.src_id, sizeof(n.src_id));
        payload->idx += sizeof(n.src_id);
        memcpy(payload->data + payload->idx, &n.port_id, sizeof(n.port_id));
        payload->idx += sizeof(n.port_id);
        memcpy(payload->data + payload->idx, &n.obj_id, sizeof(n.obj_id));
        payload->idx += sizeof(n.obj_id);
        break;
      }
    case 't': /* timeval, each field sent as 32 bits */
      {
        struct timeval t = va_arg(ap, struct timeval);
        /* int32 rather than time_t: copying the first 4 bytes of an 8-byte
           time_t picks the wrong half on big-endian hosts */
        int32 sec = htonl(t.tv_sec);
        int32 usec = htonl(t.tv_usec);

        if (putPayloadReserve(payload, 1 + 2 * sizeof(int32)) < 0)
          goto fail;
        payload->data[payload->idx++] = 't';
        memcpy(payload->data + payload->idx, &sec, sizeof(int32));
        payload->idx += sizeof(int32);
        memcpy(payload->data + payload->idx, &usec, sizeof(int32));
        payload->idx += sizeof(int32);
        break;
      }

    default: /* unknown type in format */
      warning("putPayload: invalid format '%c' in %s", *format, format);
      goto fail;
    }

    if (payload->idx > payload->len)
      payload->len = payload->idx;
  }
  va_end(ap);
  return 0;

fail:
  va_end(ap);
  return -1;
}

/*
 * Bounds helper for getPayload(). It checks, BEFORE anything is read, that
 * `need` more bytes exist at the current cursor without going past
 * payload->len. Payload contents may come from the network, so nothing
 * beyond len may be read.
 * Returns 0 if the bytes are available. Otherwise it warns, resets idx/len
 * to 0 (as the former after-the-fact check did) and returns -1.
 */
static int getPayloadAvail(Payload *payload, size_t need)
{
  size_t end = (size_t) payload->idx + need;

  if (end > (size_t) payload->len) {
    warning("getPayload: past end of Payload: idx=%d len=%d", (int) end, payload->len);
    payload->idx = payload->len = 0;
    return -1;
  }
  return 0;
}

/*
 * Deserialize items from `payload` into the pointers passed as variadic
 * arguments, as described by `format`. Reading starts at the cursor, and the
 * format characters are the same as for putPayload().
 *
 * For every format character, the payload must carry the matching one-byte
 * type tag:
 *  - an unknown format character is skipped with a warning, and no argument
 *    is consumed;
 *  - a tag mismatch is reported (warning + dump), and that format character
 *    is skipped without consuming an argument or advancing the cursor.
 *
 * For 's', up to 255 bytes plus a terminating NUL are copied into the
 * caller's buffer. Its size is unknown here, so it should hold 256 bytes.
 *
 * Returns 0 when the whole format was processed. Returns -1 for:
 *  - an invalid payload or a NULL format;
 *  - running out of payload data (before a tag or inside an item); idx/len
 *    are then reset to 0.
 */
int getPayload(Payload *payload, const char *format, ...)
{
  va_list ap;
  const char *pformat = format;

  if (! isValidPayload(payload)) {
    /* "%p": the former payload[0..3] arguments passed whole structs to
       "%02x" and dereferenced NULL when payload was NULL */
    warning("getPayload: invalid Payload: %p", (void *) payload);
    return -1;
  }
  if (format == NULL) {
    warning("getPayload: NULL format");
    return -1;
  }

  va_start(ap, format);

  for (; *format; format++) {
    u_int8 tag;

    /* Format known ? */
    if (strchr("chdfsnt", *format) == NULL) {
      warning("getPayload invalid format [%c] in %s", *format, format);
      continue;
    }

    /* a tag byte must still be present before it can be compared */
    if (getPayloadAvail(payload, 1) < 0)
      goto fail;

    /* Test matching Payload - format (read as unsigned: data[] may be char) */
    tag = (u_int8) payload->data[payload->idx];
    if (tag != (u_int8) *format) {
      warning("getPayload: mismatch '%c' (x'%02x') in payload, format='%s' [%s]",
              tag, tag, format, pformat);
      trace(DBG_NET, "len=%d, idx=%d", payload->len, payload->idx);
      dumpPayload(stdout, payload);
      continue; /* return -1; */
    }
    payload->idx++;	/* points data following */

    switch (*format) {
    case 'c': /* char */
      {
        u_int8 *p = va_arg(ap, u_int8*);

        if (getPayloadAvail(payload, sizeof(u_int8)) < 0)
          goto fail;
        memcpy(p, payload->data + payload->idx, sizeof(u_int8));
        payload->idx += sizeof(u_int8);
        break;
      }
    case 'h': /* int16 */
      {
        int16 h;
        int16 *p = va_arg(ap, int16*);

        if (getPayloadAvail(payload, sizeof(int16)) < 0)
          goto fail;
        memcpy(&h, payload->data + payload->idx, sizeof(int16));
        *p = ntohs(h);
        payload->idx += sizeof(int16);
        break;
      }
    case 'd': /* int32 */
      {
        int32 d;
        int32 *p = va_arg(ap, int32*);

        if (getPayloadAvail(payload, sizeof(int32)) < 0)
          goto fail;
        memcpy(&d, payload->data + payload->idx, sizeof(int32));
        *p = ntohl(d);
        payload->idx += sizeof(int32);
        break;
      }
    case 'f': /* float */
      {
        int32 f;
        float *p = va_arg(ap, float*);

        if (getPayloadAvail(payload, sizeof(int32)) < 0)
          goto fail;
        memcpy(&f, payload->data + payload->idx, sizeof(int32));
        f = ntohl(f);
        memcpy(p, &f, sizeof(int32));
        payload->idx += sizeof(int32);
        break;
      }
    case 's': /* string: <length byte><bytes>, NUL-terminated on output */
      {
        u_int8 *s = va_arg(ap, u_int8*);
        u_int8 len;

        if (getPayloadAvail(payload, 1) < 0)
          goto fail;
        /* unsigned read: a plain (signed) char would turn lengths >= 128
           into huge values */
        len = (u_int8) payload->data[payload->idx++];
        if (getPayloadAvail(payload, len) < 0)
          goto fail;
        memcpy(s, payload->data + payload->idx, len);
        s[len] = 0; /* NUL terminated */
        payload->idx += len;
        break;
      }
    case 'n': /* NetObjectId */
      {
        NetObjectId *n = va_arg(ap, NetObjectId*);

        if (getPayloadAvail(payload, sizeof(n->src_id)
                            + sizeof(n->port_id) + sizeof(n->obj_id)) < 0)
          goto fail;
        memcpy(&n->src_id, payload->data + payload->idx, sizeof(n->src_id));
        payload->idx += sizeof(n->src_id);
        memcpy(&n->port_id, payload->data + payload->idx, sizeof(n->port_id));
        payload->idx += sizeof(n->port_id);
        memcpy(&n->obj_id, payload->data + payload->idx, sizeof(n->obj_id));
        payload->idx += sizeof(n->obj_id);
        break;
      }
    case 't': /* timeval, each field 32 bits on the wire */
      {
        /* int32 rather than time_t: copying 4 bytes into an 8-byte time_t
           left half of it uninitialized */
        int32 sec, usec;
        struct timeval *p = va_arg(ap, struct timeval*);

        if (getPayloadAvail(payload, 2 * sizeof(int32)) < 0)
          goto fail;
        memcpy(&sec, payload->data + payload->idx, sizeof(int32));
        payload->idx += sizeof(int32);
        memcpy(&usec, payload->data + payload->idx, sizeof(int32));
        payload->idx += sizeof(int32);
        p->tv_sec = ntohl(sec);
        p->tv_usec = ntohl(usec);
        break;
      }
    default:
      warning("getPayload: format unimplemented [%c] in %s", *format, format);
      goto fail;
    }
    /* no post-read "past end" test needed: every read is checked above */
  }
  va_end(ap);
  return 0;

fail:
  va_end(ap);
  return -1;
}

/*
 * Set the read/write cursor to absolute offset `idx`.
 * The offset is not validated here. If it exceeds payload->len, the next
 * putPayload()/getPayload() call rejects the payload as invalid.
 */
void seekPayload(Payload *payload, u_int16 idx)
{
  Payload *const p = payload;

  p->idx = idx;
}

/* Return the current cursor offset (next byte to be read or written). */
u_int16 tellPayload(Payload *payload)
{
  const u_int16 offset = payload->idx;

  return offset;
}

/*
 * Scan `payload` from its beginning for a string item ('s' tag) whose
 * contents START with `str`. This is a prefix match: "foo" also matches
 * "foobar", and an empty `str` matches the first string item.
 *
 * On success, the cursor (idx) is left on that item's 's' tag byte and its
 * offset is returned, so a following getPayload(payload, "s", ...) reads it.
 * On failure (NULL argument, unknown tag, truncated item, or no match before
 * the end of the payload), a warning is emitted, idx is left unchanged and
 * -1 is returned.
 */
int tellStrInPayload(Payload *payload, const char *str)
{
  size_t want;
  size_t pos = 0;

  if (payload == NULL || str == NULL) {
    warning("tellStrInPayload: NULL argument");
    return -1;
  }
  want = strlen(str);

  while (pos < (size_t) payload->len) {
    size_t tag_pos = pos;
    /* tags and lengths are read as unsigned: data[] may be plain char */
    u_int8 format = (u_int8) payload->data[pos++];
    size_t skip;

    switch (format) {
    case 'c': /* char */
      skip = sizeof(u_int8);
      break;
    case 'h': /* int16 */
      skip = sizeof(int16);
      break;
    case 'd': /* int32 */
    case 'f': /* float */
      skip = sizeof(int32);
      break;
    case 's': /* string: <length byte><bytes> */
      {
        u_int8 slen;

        if (pos >= (size_t) payload->len)
          goto past_end;
        slen = (u_int8) payload->data[pos];
        if (pos + 1 + slen > (size_t) payload->len)
          goto past_end;
        /* Compare only inside this item. The former strncmp ignored the
           stored length and could match bytes of the following items. */
        if (slen >= want
            && memcmp(payload->data + pos + 1, str, want) == 0) {
          payload->idx = tag_pos;
          return (int) tag_pos;
        }
        /* skip length byte + string; the former in-place
           "idx += data[idx++]" was an unsequenced modification (UB) */
        skip = 1 + (size_t) slen;
        break;
      }
    case 'n': /* NetObjectId */
      skip = sizeof(u_int32) + sizeof(u_int16) + sizeof(u_int16);
      break;
    case 't': /* timeval */
      skip = sizeof(int32) + sizeof(int32);
      break;
    default:
      warning("tellStrInPayload: invalid format [%c]", format);
      return -1;
    }
    pos += skip;
  }

past_end:
  warning("tellStrInPayload: past end of Payload: idx=%d len=%d", (int) pos, payload->len);
  return -1;
}

