72
72
73
73
#include < list>
74
74
#include < map>
75
+ #include < iterator>
76
+ #include < algorithm>
75
77
76
78
/* This structure is the key for looking up services in the
77
79
port/proto -> service map. */
@@ -332,30 +334,32 @@ static int port_compare(const void *a, const void *b) {
332
334
}
333
335
334
336
337
+ template <typename T>
338
+ class C_array_iterator : public std ::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t > {
339
+ T *ptr;
340
+ public:
341
+ C_array_iterator (T *_ptr=NULL ) : ptr(_ptr) {}
342
+ C_array_iterator (const C_array_iterator &other) : ptr(other.ptr) {}
343
+ C_array_iterator& operator =(T *_ptr) {ptr = _ptr; return *this ;}
344
+ C_array_iterator& operator ++() {ptr++; return *this ;}
345
+ C_array_iterator operator ++(int ) {C_array_iterator retval = *this ; ++(*this ); return retval;}
346
+ C_array_iterator& operator --() {ptr--; return *this ;}
347
+ C_array_iterator operator --(int ) {C_array_iterator retval = *this ; --(*this ); return retval;}
348
+ bool operator ==(C_array_iterator &other) const {return ptr == other.ptr ;}
349
+ bool operator !=(C_array_iterator &other) const {return !(*this == other);}
350
+ bool operator <(C_array_iterator &other) const {return ptr < other.ptr ;}
351
+ C_array_iterator& operator +=(std::ptrdiff_t n) {ptr += n; return *this ;}
352
+ C_array_iterator& operator -=(std::ptrdiff_t n) {ptr -= n; return *this ;}
353
+ std::ptrdiff_t operator +(const C_array_iterator &other) {return ptr + other.ptr ;}
354
+ std::ptrdiff_t operator -(const C_array_iterator &other) {return ptr - other.ptr ;}
355
+ T& operator *() const {return *ptr;}
356
+ };
335
357
336
- // is_port_member() returns true if serv is an element of ptsdata.
337
- // This could be implemented MUCH more efficiently but it should only be
338
- // called when you use a non-default top-ports or port-ratio value TOGETHER WITH
339
- // a -p portlist.
340
-
341
- static bool is_port_member (const struct scan_lists *ptsdata, const struct service_node *serv) {
342
- int i;
343
-
344
- if (strcmp (serv->s_proto , " tcp" ) == 0 ) {
345
- for (i=0 ; i<ptsdata->tcp_count ; i++)
346
- if (serv->s_port == ptsdata->tcp_ports [i])
347
- return true ;
348
- } else if (strcmp (serv->s_proto , " udp" ) == 0 ) {
349
- for (i=0 ; i<ptsdata->udp_count ; i++)
350
- if (serv->s_port == ptsdata->udp_ports [i])
351
- return true ;
352
- } else if (strcmp (serv->s_proto , " sctp" ) == 0 ) {
353
- for (i=0 ; i<ptsdata->sctp_count ; i++)
354
- if (serv->s_port == ptsdata->sctp_ports [i])
355
- return true ;
356
- }
357
-
358
- return false ;
358
+ // is_port_member() returns true if serv->s_port is an element of pts.
359
+ static bool is_port_member (unsigned short *pts, int count, const struct service_node *serv) {
360
+ C_array_iterator<unsigned short > begin = pts;
361
+ C_array_iterator<unsigned short > end = pts + count;
362
+ return std::binary_search (begin, end, serv->s_port );
359
363
}
360
364
361
365
// gettoppts() sets its third parameter, a scan_list, with the most
@@ -375,7 +379,6 @@ static bool is_port_member(const struct scan_lists *ptsdata, const struct servic
375
379
// function if o.TCPScan() || o.UDPScan() || o.SCTPScan()
376
380
377
381
void gettoppts (double level, const char *portlist, struct scan_lists * ports, const char *exclude_ports) {
378
- int ti=0 , ui=0 , si=0 ;
379
382
struct scan_lists ptsdata = { 0 };
380
383
bool ptsdata_initialized = false ;
381
384
const struct service_node *current;
@@ -421,85 +424,83 @@ void gettoppts(double level, const char *portlist, struct scan_lists * ports, co
421
424
if (ptsdata_initialized && exclude_ports)
422
425
removepts (exclude_ports, &ptsdata);
423
426
424
- if (level < 1 ) {
425
- for (i = services_by_ratio.begin (); i != services_by_ratio.end (); i++) {
426
- current = &(*i);
427
- if (ptsdata_initialized && !is_port_member (&ptsdata, current))
428
- continue ;
429
- if (current->ratio >= level) {
430
- if (o.TCPScan () && strcmp (current->s_proto , " tcp" ) == 0 )
431
- ports->tcp_count ++;
432
- else if (o.UDPScan () && strcmp (current->s_proto , " udp" ) == 0 )
433
- ports->udp_count ++;
434
- else if (o.SCTPScan () && strcmp (current->s_proto , " sctp" ) == 0 )
435
- ports->sctp_count ++;
436
- } else {
437
- break ;
438
- }
439
- }
440
-
441
- if (ports->tcp_count )
442
- ports->tcp_ports = (unsigned short *)safe_zalloc (ports->tcp_count * sizeof (unsigned short ));
443
-
444
- if (ports->udp_count )
445
- ports->udp_ports = (unsigned short *)safe_zalloc (ports->udp_count * sizeof (unsigned short ));
446
-
447
- if (ports->sctp_count )
448
- ports->sctp_ports = (unsigned short *)safe_zalloc (ports->sctp_count * sizeof (unsigned short ));
449
-
450
- ports->prots = NULL ;
451
-
452
- for (i = services_by_ratio.begin (); i != services_by_ratio.end (); i++) {
453
- current = &(*i);
454
- if (ptsdata_initialized && !is_port_member (&ptsdata, current))
455
- continue ;
456
- if (current->ratio >= level) {
457
- if (o.TCPScan () && strcmp (current->s_proto , " tcp" ) == 0 )
458
- ports->tcp_ports [ti++] = current->s_port ;
459
- else if (o.UDPScan () && strcmp (current->s_proto , " udp" ) == 0 )
460
- ports->udp_ports [ui++] = current->s_port ;
461
- else if (o.SCTPScan () && strcmp (current->s_proto , " sctp" ) == 0 )
462
- ports->sctp_ports [si++] = current->s_port ;
463
- } else {
464
- break ;
465
- }
466
- }
467
- } else if (level >= 1 ) {
427
+ /* Max number of ports for each protocol cannot be more than the minimum of:
428
+ * 1. all of them (65536)
429
+ * 2. requested ports (ptsdata)
430
+ * 3. the number in services db (numXXXports)
431
+ */
432
+ int tcpmax = o.TCPScan () ? (ptsdata_initialized ? ptsdata.tcp_count : 65536 ) : 0 ;
433
+ tcpmax = MIN (tcpmax, numtcpports);
434
+ int udpmax = o.UDPScan () ? (ptsdata_initialized ? ptsdata.udp_count : 65536 ) : 0 ;
435
+ udpmax = MIN (udpmax, numudpports);
436
+ int sctpmax = o.SCTPScan () ? (ptsdata_initialized ? ptsdata.sctp_count : 65536 ) : 0 ;
437
+ sctpmax = MIN (sctpmax, numsctpports);
438
+
439
+ // If level is positive integer, it's the max number of ports.
440
+ if (level >= 1 ) {
468
441
if (level > 65536 )
469
442
fatal (" Level argument to gettoppts (%g) is too large" , level);
443
+ tcpmax = MIN ((int ) level, tcpmax);
444
+ udpmax = MIN ((int ) level, udpmax);
445
+ sctpmax = MIN ((int ) level, sctpmax);
446
+ // Now force the ratio comparison to always be true:
447
+ level = 0 ;
448
+ }
449
+ else if (level <= 0 ) {
450
+ fatal (" Argument to gettoppts (%g) should be a positive ratio below 1 or an integer of 1 or higher" , level);
451
+ }
452
+ // else level is a ratio between 0 and 1
470
453
471
- if (o.TCPScan ()) {
472
- ports->tcp_count = MIN ((int ) level, numtcpports);
473
- ports->tcp_ports = (unsigned short *)safe_zalloc (ports->tcp_count * sizeof (unsigned short ));
474
- }
475
- if (o.UDPScan ()) {
476
- ports->udp_count = MIN ((int ) level, numudpports);
477
- ports->udp_ports = (unsigned short *)safe_zalloc (ports->udp_count * sizeof (unsigned short ));
478
- }
479
- if (o.SCTPScan ()) {
480
- ports->sctp_count = MIN ((int ) level, numsctpports);
481
- ports->sctp_ports = (unsigned short *)safe_zalloc (ports->sctp_count * sizeof (unsigned short ));
482
- }
454
+ // These could be 0/false if the scan type was not requested.
455
+ if (tcpmax) {
456
+ ports->tcp_ports = (unsigned short *)safe_zalloc (tcpmax * sizeof (unsigned short ));
457
+ }
458
+ if (udpmax) {
459
+ ports->udp_ports = (unsigned short *)safe_zalloc (udpmax * sizeof (unsigned short ));
460
+ }
461
+ if (sctpmax) {
462
+ ports->sctp_ports = (unsigned short *)safe_zalloc (sctpmax * sizeof (unsigned short ));
463
+ }
483
464
484
- ports->prots = NULL ;
465
+ ports->prots = NULL ;
485
466
486
- for (i = services_by_ratio.begin (); i != services_by_ratio.end (); i++) {
487
- current = &(*i);
488
- if (ptsdata_initialized && !is_port_member (&ptsdata, current))
489
- continue ;
490
- if (o.TCPScan () && strcmp (current->s_proto , " tcp" ) == 0 && ti < ports->tcp_count )
491
- ports->tcp_ports [ti++] = current->s_port ;
492
- else if (o.UDPScan () && strcmp (current->s_proto , " udp" ) == 0 && ui < ports->udp_count )
493
- ports->udp_ports [ui++] = current->s_port ;
494
- else if (o.SCTPScan () && strcmp (current->s_proto , " sctp" ) == 0 && si < ports->sctp_count )
495
- ports->sctp_ports [si++] = current->s_port ;
467
+ // Loop until we get enough or run out of candidates
468
+ for (i = services_by_ratio.begin (); i != services_by_ratio.end () && (tcpmax || udpmax || sctpmax); i++) {
469
+ current = &(*i);
470
+ if (current->ratio < level) {
471
+ break ;
496
472
}
497
-
498
- if (ti < ports->tcp_count ) ports->tcp_count = ti;
499
- if (ui < ports->udp_count ) ports->udp_count = ui;
500
- if (si < ports->sctp_count ) ports->sctp_count = si;
501
- } else
502
- fatal (" Argument to gettoppts (%g) should be a positive ratio below 1 or an integer of 1 or higher" , level);
473
+ switch (current->s_proto [0 ]) {
474
+ case ' t' :
475
+ if (tcpmax && strcmp (current->s_proto , " tcp" ) == 0
476
+ && (!ptsdata_initialized ||
477
+ is_port_member (ptsdata.tcp_ports , ptsdata.tcp_count , current))
478
+ ) {
479
+ ports->tcp_ports [ports->tcp_count ++] = current->s_port ;
480
+ tcpmax--;
481
+ }
482
+ break ;
483
+ case ' u' :
484
+ if (udpmax && strcmp (current->s_proto , " udp" ) == 0
485
+ && (!ptsdata_initialized ||
486
+ is_port_member (ptsdata.udp_ports , ptsdata.udp_count , current))
487
+ ) {
488
+ ports->udp_ports [ports->udp_count ++] = current->s_port ;
489
+ udpmax--;
490
+ }
491
+ break ;
492
+ case ' s' :
493
+ if (sctpmax && strcmp (current->s_proto , " sctp" ) == 0
494
+ && (!ptsdata_initialized ||
495
+ is_port_member (ptsdata.sctp_ports , ptsdata.sctp_count , current))
496
+ )
497
+ ports->sctp_ports [ports->sctp_count ++] = current->s_port ;
498
+ sctpmax--;
499
+ break ;
500
+ default :
501
+ break ;
502
+ }
503
+ }
503
504
504
505
if (ptsdata_initialized) {
505
506
free_scan_lists (&ptsdata);
0 commit comments