Skip to content

Commit 188a3ac

Browse files
committed
Clarify and optimize top-ports checking
1 parent 92b68cb commit 188a3ac

File tree

1 file changed

+98
-97
lines changed

1 file changed

+98
-97
lines changed

services.cc

+98-97
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272

7373
#include <list>
7474
#include <map>
75+
#include <iterator>
76+
#include <algorithm>
7577

7678
/* This structure is the key for looking up services in the
7779
port/proto -> service map. */
@@ -332,30 +334,32 @@ static int port_compare(const void *a, const void *b) {
332334
}
333335

334336

337+
template <typename T>
338+
class C_array_iterator: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t> {
339+
T *ptr;
340+
public:
341+
C_array_iterator(T *_ptr=NULL) : ptr(_ptr) {}
342+
C_array_iterator(const C_array_iterator &other) : ptr(other.ptr) {}
343+
C_array_iterator& operator=(T *_ptr) {ptr = _ptr; return *this;}
344+
C_array_iterator& operator++() {ptr++; return *this;}
345+
C_array_iterator operator++(int) {C_array_iterator retval = *this; ++(*this); return retval;}
346+
C_array_iterator& operator--() {ptr--; return *this;}
347+
C_array_iterator operator--(int) {C_array_iterator retval = *this; --(*this); return retval;}
348+
bool operator==(C_array_iterator &other) const {return ptr == other.ptr;}
349+
bool operator!=(C_array_iterator &other) const {return !(*this == other);}
350+
bool operator<(C_array_iterator &other) const {return ptr < other.ptr;}
351+
C_array_iterator& operator+=(std::ptrdiff_t n) {ptr += n; return *this;}
352+
C_array_iterator& operator-=(std::ptrdiff_t n) {ptr -= n; return *this;}
353+
std::ptrdiff_t operator+(const C_array_iterator &other) {return ptr + other.ptr;}
354+
std::ptrdiff_t operator-(const C_array_iterator &other) {return ptr - other.ptr;}
355+
T& operator*() const {return *ptr;}
356+
};
335357

336-
// is_port_member() returns true if serv is an element of ptsdata.
337-
// This could be implemented MUCH more efficiently but it should only be
338-
// called when you use a non-default top-ports or port-ratio value TOGETHER WITH
339-
// a -p portlist.
340-
341-
static bool is_port_member(const struct scan_lists *ptsdata, const struct service_node *serv) {
342-
int i;
343-
344-
if (strcmp(serv->s_proto, "tcp") == 0) {
345-
for (i=0; i<ptsdata->tcp_count; i++)
346-
if (serv->s_port == ptsdata->tcp_ports[i])
347-
return true;
348-
} else if (strcmp(serv->s_proto, "udp") == 0) {
349-
for (i=0; i<ptsdata->udp_count; i++)
350-
if (serv->s_port == ptsdata->udp_ports[i])
351-
return true;
352-
} else if (strcmp(serv->s_proto, "sctp") == 0) {
353-
for (i=0; i<ptsdata->sctp_count; i++)
354-
if (serv->s_port == ptsdata->sctp_ports[i])
355-
return true;
356-
}
357-
358-
return false;
358+
// is_port_member() returns true if serv->s_port is an element of pts.
359+
static bool is_port_member(unsigned short *pts, int count, const struct service_node *serv) {
360+
C_array_iterator<unsigned short> begin = pts;
361+
C_array_iterator<unsigned short> end = pts + count;
362+
return std::binary_search(begin, end, serv->s_port);
359363
}
360364

361365
// gettoppts() sets its third parameter, a scan_list, with the most
@@ -375,7 +379,6 @@ static bool is_port_member(const struct scan_lists *ptsdata, const struct servic
375379
// function if o.TCPScan() || o.UDPScan() || o.SCTPScan()
376380

377381
void gettoppts(double level, const char *portlist, struct scan_lists * ports, const char *exclude_ports) {
378-
int ti=0, ui=0, si=0;
379382
struct scan_lists ptsdata = { 0 };
380383
bool ptsdata_initialized = false;
381384
const struct service_node *current;
@@ -421,85 +424,83 @@ void gettoppts(double level, const char *portlist, struct scan_lists * ports, co
421424
if (ptsdata_initialized && exclude_ports)
422425
removepts(exclude_ports, &ptsdata);
423426

424-
if (level < 1) {
425-
for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) {
426-
current = &(*i);
427-
if (ptsdata_initialized && !is_port_member(&ptsdata, current))
428-
continue;
429-
if (current->ratio >= level) {
430-
if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0)
431-
ports->tcp_count++;
432-
else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0)
433-
ports->udp_count++;
434-
else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0)
435-
ports->sctp_count++;
436-
} else {
437-
break;
438-
}
439-
}
440-
441-
if (ports->tcp_count)
442-
ports->tcp_ports = (unsigned short *)safe_zalloc(ports->tcp_count * sizeof(unsigned short));
443-
444-
if (ports->udp_count)
445-
ports->udp_ports = (unsigned short *)safe_zalloc(ports->udp_count * sizeof(unsigned short));
446-
447-
if (ports->sctp_count)
448-
ports->sctp_ports = (unsigned short *)safe_zalloc(ports->sctp_count * sizeof(unsigned short));
449-
450-
ports->prots = NULL;
451-
452-
for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) {
453-
current = &(*i);
454-
if (ptsdata_initialized && !is_port_member(&ptsdata, current))
455-
continue;
456-
if (current->ratio >= level) {
457-
if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0)
458-
ports->tcp_ports[ti++] = current->s_port;
459-
else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0)
460-
ports->udp_ports[ui++] = current->s_port;
461-
else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0)
462-
ports->sctp_ports[si++] = current->s_port;
463-
} else {
464-
break;
465-
}
466-
}
467-
} else if (level >= 1) {
427+
/* Max number of ports for each protocol cannot be more than the minimum of:
428+
* 1. all of them (65536)
429+
* 2. requested ports (ptsdata)
430+
* 3. the number in services db (numXXXports)
431+
*/
432+
int tcpmax = o.TCPScan() ? (ptsdata_initialized ? ptsdata.tcp_count : 65536) : 0;
433+
tcpmax = MIN(tcpmax, numtcpports);
434+
int udpmax = o.UDPScan() ? (ptsdata_initialized ? ptsdata.udp_count : 65536) : 0;
435+
udpmax = MIN(udpmax, numudpports);
436+
int sctpmax = o.SCTPScan() ? (ptsdata_initialized ? ptsdata.sctp_count : 65536) : 0;
437+
sctpmax = MIN(sctpmax, numsctpports);
438+
439+
// If level is positive integer, it's the max number of ports.
440+
if (level >= 1) {
468441
if (level > 65536)
469442
fatal("Level argument to gettoppts (%g) is too large", level);
443+
tcpmax = MIN((int) level, tcpmax);
444+
udpmax = MIN((int) level, udpmax);
445+
sctpmax = MIN((int) level, sctpmax);
446+
// Now force the ratio comparison to always be true:
447+
level = 0;
448+
}
449+
else if (level <= 0) {
450+
fatal("Argument to gettoppts (%g) should be a positive ratio below 1 or an integer of 1 or higher", level);
451+
}
452+
// else level is a ratio between 0 and 1
470453

471-
if (o.TCPScan()) {
472-
ports->tcp_count = MIN((int) level, numtcpports);
473-
ports->tcp_ports = (unsigned short *)safe_zalloc(ports->tcp_count * sizeof(unsigned short));
474-
}
475-
if (o.UDPScan()) {
476-
ports->udp_count = MIN((int) level, numudpports);
477-
ports->udp_ports = (unsigned short *)safe_zalloc(ports->udp_count * sizeof(unsigned short));
478-
}
479-
if (o.SCTPScan()) {
480-
ports->sctp_count = MIN((int) level, numsctpports);
481-
ports->sctp_ports = (unsigned short *)safe_zalloc(ports->sctp_count * sizeof(unsigned short));
482-
}
454+
// These could be 0/false if the scan type was not requested.
455+
if (tcpmax) {
456+
ports->tcp_ports = (unsigned short *)safe_zalloc(tcpmax * sizeof(unsigned short));
457+
}
458+
if (udpmax) {
459+
ports->udp_ports = (unsigned short *)safe_zalloc(udpmax * sizeof(unsigned short));
460+
}
461+
if (sctpmax) {
462+
ports->sctp_ports = (unsigned short *)safe_zalloc(sctpmax * sizeof(unsigned short));
463+
}
483464

484-
ports->prots = NULL;
465+
ports->prots = NULL;
485466

486-
for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) {
487-
current = &(*i);
488-
if (ptsdata_initialized && !is_port_member(&ptsdata, current))
489-
continue;
490-
if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0 && ti < ports->tcp_count)
491-
ports->tcp_ports[ti++] = current->s_port;
492-
else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0 && ui < ports->udp_count)
493-
ports->udp_ports[ui++] = current->s_port;
494-
else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0 && si < ports->sctp_count)
495-
ports->sctp_ports[si++] = current->s_port;
467+
// Loop until we get enough or run out of candidates
468+
for (i = services_by_ratio.begin(); i != services_by_ratio.end() && (tcpmax || udpmax || sctpmax); i++) {
469+
current = &(*i);
470+
if (current->ratio < level) {
471+
break;
496472
}
497-
498-
if (ti < ports->tcp_count) ports->tcp_count = ti;
499-
if (ui < ports->udp_count) ports->udp_count = ui;
500-
if (si < ports->sctp_count) ports->sctp_count = si;
501-
} else
502-
fatal("Argument to gettoppts (%g) should be a positive ratio below 1 or an integer of 1 or higher", level);
473+
switch (current->s_proto[0]) {
474+
case 't':
475+
if (tcpmax && strcmp(current->s_proto, "tcp") == 0
476+
&& (!ptsdata_initialized ||
477+
is_port_member(ptsdata.tcp_ports, ptsdata.tcp_count, current))
478+
) {
479+
ports->tcp_ports[ports->tcp_count++] = current->s_port;
480+
tcpmax--;
481+
}
482+
break;
483+
case 'u':
484+
if (udpmax && strcmp(current->s_proto, "udp") == 0
485+
&& (!ptsdata_initialized ||
486+
is_port_member(ptsdata.udp_ports, ptsdata.udp_count, current))
487+
) {
488+
ports->udp_ports[ports->udp_count++] = current->s_port;
489+
udpmax--;
490+
}
491+
break;
492+
case 's':
493+
if (sctpmax && strcmp(current->s_proto, "sctp") == 0
494+
&& (!ptsdata_initialized ||
495+
is_port_member(ptsdata.sctp_ports, ptsdata.sctp_count, current))
496+
)
497+
ports->sctp_ports[ports->sctp_count++] = current->s_port;
498+
sctpmax--;
499+
break;
500+
default:
501+
break;
502+
}
503+
}
503504

504505
if (ptsdata_initialized) {
505506
free_scan_lists(&ptsdata);

0 commit comments

Comments
 (0)