Commit 69d10132 authored by Toke Høiland-Jørgensen's avatar Toke Høiland-Jørgensen Committed by Ondrej Zajicek (work)
Browse files

Babel: Refactor TLV parsing code for easier reuse

In preparation for adding authentication checks, refactor the TLV
walking code so it can be reused for a separate pass of the packet
for authentication checks.
parent 589f7d1e
Loading
Loading
Loading
Loading
+107 −64
Original line number Diff line number Diff line
@@ -120,8 +120,19 @@ struct babel_subtlv_source_prefix {
#define BABEL_UF_DEF_PREFIX	0x80
#define BABEL_UF_ROUTER_ID	0x40

struct babel_parse_state;
struct babel_write_state;

struct babel_tlv_data {
  u8 min_length;
  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
};

struct babel_parse_state {
  const struct babel_tlv_data* (*get_tlv_data)(u8 type);
  const struct babel_tlv_data* (*get_subtlv_data)(u8 type);
  struct babel_proto *proto;
  struct babel_iface *ifa;
  ip_addr saddr;
@@ -167,6 +178,37 @@ struct babel_write_state {

#define NET_SIZE(n) BYTES(net_pxlen(n))


/* Helper macros to loop over a series of TLVs.
 * @start      pointer to first TLV (void * or struct babel_tlv *)
 * @end        byte * pointer to TLV stream end
 * @tlv        struct babel_tlv pointer used as iterator
 * @frame_err  boolean (u8) that will be set to 1 if a frame error occurred
 * @saddr      source addr for use in log output
 * @ifname     ifname for use in log output
 */
#define WALK_TLVS(start, end, tlv, frame_err, saddr, ifname)    \
  for (tlv = start;						\
       (byte *)tlv < end;                                               \
       tlv = NEXT_TLV(tlv))						\
  {									\
    byte *loop_pos;							\
    /* Ugly special case */						\
    if (tlv->type == BABEL_TLV_PAD1)					\
      continue;                                                         \
									\
    /* The end of the common TLV header */				\
    loop_pos = (byte *)tlv + sizeof(struct babel_tlv);			\
    if ((loop_pos > end) || (loop_pos + tlv->length > end))             \
    {                                                                   \
      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",  \
	      saddr, ifname, tlv->type, (byte *)tlv - (byte *)start);   \
      frame_err = 1;                                                    \
      break;                                                            \
    }

#define WALK_TLVS_END }

static inline uint
bytes_equal(u8 *b1, u8 *b2, uint maxlen)
{
@@ -255,13 +297,6 @@ static uint babel_write_route_request(struct babel_tlv *hdr, union babel_msg *ms
static uint babel_write_seqno_request(struct babel_tlv *hdr, union babel_msg *msg, struct babel_write_state *state, uint max_len);
static int babel_write_source_prefix(struct babel_tlv *hdr, net_addr *net, uint max_len);

struct babel_tlv_data {
  u8 min_length;
  int (*read_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_parse_state *state);
  uint (*write_tlv)(struct babel_tlv *hdr, union babel_msg *m, struct babel_write_state *state, uint max_len);
  void (*handle_tlv)(union babel_msg *m, struct babel_iface *ifa);
};

static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
  [BABEL_TLV_ACK_REQ] = {
    sizeof(struct babel_tlv_ack_req),
@@ -319,6 +354,30 @@ static const struct babel_tlv_data tlv_data[BABEL_TLV_MAX] = {
  },
};

static const struct babel_tlv_data *get_packet_tlv_data(u8 type)
{
  return type < sizeof(tlv_data) / sizeof(*tlv_data) ? &tlv_data[type] : NULL;
}

static const struct babel_tlv_data source_prefix_tlv_data = {
  sizeof(struct babel_subtlv_source_prefix),
  babel_read_source_prefix,
  NULL,
  NULL
};

static const struct babel_tlv_data *get_packet_subtlv_data(u8 type)
{
  switch(type)
  {
  case BABEL_SUBTLV_SOURCE_PREFIX:
    return &source_prefix_tlv_data;

  default:
    return NULL;
  }
}

static int
babel_read_ack_req(struct babel_tlv *hdr, union babel_msg *m,
		   struct babel_parse_state *state)
@@ -1083,69 +1142,65 @@ babel_write_source_prefix(struct babel_tlv *hdr, net_addr *n, uint max_len)
  return len;
}


static inline int
babel_read_subtlvs(struct babel_tlv *hdr,
		   union babel_msg *msg,
		   struct babel_parse_state *state)
{
  const struct babel_tlv_data *tlv_data;
  struct babel_proto *p = state->proto;
  struct babel_tlv *tlv;
  byte *pos, *end = (byte *) hdr + TLV_LENGTH(hdr);
  byte *end = (byte *) hdr + TLV_LENGTH(hdr);
  u8 frame_err = 0;
  int res;

  for (tlv = (void *) hdr + state->current_tlv_endpos;
       (byte *) tlv < end;
       tlv = NEXT_TLV(tlv))
  WALK_TLVS((void *)hdr + state->current_tlv_endpos, end, tlv, frame_err,
            state->saddr, state->ifa->ifname)
  {
    /* Ugly special case */
    if (tlv->type == BABEL_TLV_PAD1)
    if (tlv->type == BABEL_SUBTLV_PADN)
      continue;

    /* The end of the common TLV header */
    pos = (byte *)tlv + sizeof(struct babel_tlv);
    if ((pos > end) || (pos + tlv->length > end))
      return PARSE_ERROR;

    /*
     * The subtlv type space is non-contiguous (due to the mandatory bit), so
     * use a switch for dispatch instead of the mapping array we use for TLVs
     */
    switch (tlv->type)
    if (!state->get_subtlv_data ||
        !(tlv_data = state->get_subtlv_data(tlv->type)) ||
        !tlv_data->read_tlv)
    {
    case BABEL_SUBTLV_SOURCE_PREFIX:
      res = babel_read_source_prefix(tlv, msg, state);
      if (res != PARSE_SUCCESS)
	return res;
      break;

    case BABEL_SUBTLV_PADN:
    default:
      /* Unknown mandatory subtlv; PARSE_IGNORE ignores the whole TLV */
      if (tlv->type >= 128)
        return PARSE_IGNORE;
      break;
      continue;
    }

    res = tlv_data->read_tlv(tlv, msg, state);
    if (res != PARSE_SUCCESS)
      return res;
  }
  WALK_TLVS_END;

  return PARSE_SUCCESS;
  return frame_err ? PARSE_ERROR : PARSE_SUCCESS;
}

static inline int
static int
babel_read_tlv(struct babel_tlv *hdr,
               union babel_msg *msg,
               struct babel_parse_state *state)
{
  const struct babel_tlv_data *tlv_data;

  if ((hdr->type <= BABEL_TLV_PADN) ||
      (hdr->type >= BABEL_TLV_MAX) ||
      !tlv_data[hdr->type].read_tlv)
      (hdr->type >= BABEL_TLV_MAX))
    return PARSE_IGNORE;

  tlv_data = state->get_tlv_data(hdr->type);

  if (!tlv_data || !tlv_data->read_tlv)
    return PARSE_IGNORE;

  if (TLV_LENGTH(hdr) < tlv_data[hdr->type].min_length)
  if (TLV_LENGTH(hdr) < tlv_data->min_length)
    return PARSE_ERROR;

  state->current_tlv_endpos = tlv_data[hdr->type].min_length;
  state->current_tlv_endpos = tlv_data->min_length;

  int res = tlv_data[hdr->type].read_tlv(hdr, msg, state);
  int res = tlv_data->read_tlv(hdr, msg, state);
  if (res != PARSE_SUCCESS)
    return res;

@@ -1330,6 +1385,7 @@ static void
babel_process_packet(struct babel_pkt_header *pkt, int len,
                     ip_addr saddr, struct babel_iface *ifa)
{
  u8 frame_err UNUSED = 0;
  struct babel_proto *p = ifa->proto;
  struct babel_tlv *tlv;
  struct babel_msg_node *msg;
@@ -1337,10 +1393,11 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
  int res;

  int plen = sizeof(struct babel_pkt_header) + get_u16(&pkt->length);
  byte *pos;
  byte *end = (byte *)pkt + plen;

  struct babel_parse_state state = {
    .get_tlv_data    = &get_packet_tlv_data,
    .get_subtlv_data = &get_packet_subtlv_data,
    .proto           = p,
    .ifa             = ifa,
    .saddr           = saddr,
@@ -1369,23 +1426,8 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,

  /* First pass through the packet TLV by TLV, parsing each into internal data
     structures. */
  for (tlv = FIRST_TLV(pkt);
       (byte *)tlv < end;
       tlv = NEXT_TLV(tlv))
  WALK_TLVS(FIRST_TLV(pkt), end, tlv, frame_err, saddr, ifa->iface->name)
  {
    /* Ugly special case */
    if (tlv->type == BABEL_TLV_PAD1)
      continue;

    /* The end of the common TLV header */
    pos = (byte *)tlv + sizeof(struct babel_tlv);
    if ((pos > end) || (pos + tlv->length > end))
    {
      LOG_PKT("Bad TLV from %I via %s type %d pos %d - framing error",
	      saddr, ifa->iface->name, tlv->type, (byte *)tlv - (byte *)pkt);
      break;
    }

    msg = sl_allocz(p->msg_slab);
    res = babel_read_tlv(tlv, &msg->msg, &state);
    if (res == PARSE_SUCCESS)
@@ -1405,8 +1447,9 @@ babel_process_packet(struct babel_pkt_header *pkt, int len,
      break;
    }
  }
  WALK_TLVS_END;

  /* Parsing done, handle all parsed TLVs */
  /* Parsing done, handle all parsed TLVs, regardless of any errors */
  WALK_LIST_FIRST(msg, msgs)
  {
    if (tlv_data[msg->msg.type].handle_tlv)