16 #include <netinet/in.h>
19 #include <sys/types.h>
20 #include <sys/socket.h>
// Socket type aliases for the Winsock-based branch (SOCKET / INVALID_SOCKET).
// The preprocessor guard that selects this branch is not visible in this excerpt.
35 typedef SOCKET socket_type;
36 typedef int socklen_t;
// Send/receive/option buffers are passed as char pointers in this branch.
37 typedef const char* send_ptr_type;
38 typedef char* recv_ptr_type;
39 typedef const char* sockopt_ptr_type;
// Sentinel value stored by the socket wrappers when they hold no socket.
41 static const socket_type BAD_SOCKET = INVALID_SOCKET;
// bzero( addr, len ): zero-fills `len` bytes starting at `addr` by forwarding
// to ::memset.
//   addr - start of the buffer to clear
//   len  - number of bytes to clear
44 bzero(
void* addr,
size_t len )
46 ::memset( addr, 0, len );
// closesocket( s ): close helper taking the platform socket_type. FullSocket
// calls detail::closesocket( s_ ) on its stored socket.
// The body is not visible in this excerpt.
50 closesocket( socket_type s )
// Winsock start-up: requests version 2.0 through WSAStartup and throws
// std::runtime_error when the call returns non-zero.
// The declaration of `wsaData` and the enclosing function header are on lines
// not shown in this excerpt.
60 WORD wVersionRequested;
63 wVersionRequested = MAKEWORD( 2, 0 );
64 if ( WSAStartup( wVersionRequested, &wsaData ) != 0 )
66 throw std::runtime_error(
67 "Unable to initialize winsock library" );
// Static wsa_init instance; presumably constructing it runs the start-up code
// above - TODO confirm against the full class definition.
77 static wsa_init _init_obj;
// Socket type aliases for the branch where a socket is a plain int descriptor.
// The preprocessor guard that selects this branch is not visible in this excerpt.
80 typedef int socket_type;
// Makes the global ::socklen_t available under this scope (used later as
// detail::socklen_t).
81 typedef ::socklen_t socklen_t;
// Send/receive/option buffers are passed as void pointers in this branch.
82 typedef const void* send_ptr_type;
83 typedef void* recv_ptr_type;
84 typedef const void* sockopt_ptr_type;
// Sentinel value stored by the socket wrappers when they hold no socket.
86 static const socket_type BAD_SOCKET = -1;
// bzero( addr, len ) for the int-descriptor branch; same parameter list as the
// Winsock-branch helper. The body is not visible in this excerpt.
89 bzero(
void* addr,
size_t len )
// closesocket( s ) for the int-descriptor branch, giving callers one name for
// closing a socket_type. The body is not visible in this excerpt.
95 closesocket( socket_type s )
// Address field holding the node (host) address. Its type is taken from
// sockaddr_in::sin_addr.s_addr, so it can be assigned to that member without a
// conversion (see the assignment in FullSocket::bind).
104 decltype(::sockaddr_in::sin_addr.s_addr ) node;
// Alias for the platform socket descriptor type.
111 typedef detail::socket_type socket_type;
// Default construction: holds no socket (s_ is the BAD_SOCKET sentinel).
113 FullSocket( ) : s_( detail::BAD_SOCKET )
// Move construction: takes the descriptor returned by other.release( ).
116 FullSocket( FullSocket&& other ) : s_{ other.release( ) }
// Wraps an existing descriptor `s`. Marked explicit, so a raw socket_type does
// not convert to FullSocket implicitly.
120 explicit FullSocket( socket_type s ) : s_( s )
// Closes the stored socket when one is held and sets s_ back to the
// BAD_SOCKET sentinel.
// NOTE(review): the header of the enclosing member function is not visible in
// this excerpt (destructor or close helper - TODO confirm).
126 if ( s_ != detail::BAD_SOCKET )
128 detail::closesocket( s_ );
130 s_ = detail::BAD_SOCKET;
// reset( fd ): makes this object hold descriptor `fd`; the tests show get( )
// returning the new value afterwards. Whether the previous descriptor is closed
// first cannot be seen here - the body is not visible, TODO confirm.
133 reset( socket_type fd )
// Saves the current descriptor in `tmp` and sets s_ to the BAD_SOCKET sentinel.
// Presumably the body of release( ): the tests show release( ) returning the old
// descriptor and get( ) returning BAD_SOCKET afterwards. The function header and
// return statement are not visible here - TODO confirm.
140 socket_type tmp = s_;
141 s_ = detail::BAD_SOCKET;
// swap( other ): presumably exchanges the stored descriptors of *this and
// `other`. Only the first step (copying other.s_ into `tmp`) is visible in this
// excerpt.
152 swap( FullSocket& other )
154 socket_type tmp = other.s_;
// True when a descriptor is held, i.e. s_ differs from the BAD_SOCKET sentinel.
// Presumably the body of good( ); its header is not visible in this excerpt.
162 return s_ != detail::BAD_SOCKET;
// bind( address ): declaration only. The definition parses `address` with
// parse_address, creates an AF_INET stream socket, sets SO_REUSEADDR and calls
// ::bind; it throws std::runtime_error on its error paths.
165 void bind(
const std::string& address );
// listen( queue_depth ): puts the socket into listening mode.
//   queue_depth - backlog value handed to ::listen (default 100)
// Returns sock.sin_port exactly as filled in by getsockname, with no byte-order
// conversion here; the two-argument overload applies ntohs to this value.
// Throws std::runtime_error on each error path shown below.
168 listen(
int queue_depth = 100 )
// Guard whose condition is on a line not shown here; the message indicates it
// rejects a socket that has not been bound yet.
171 throw std::runtime_error(
172 "Need to bind a port before listening" );
173 auto rc = ::listen(
get( ), queue_depth );
// Address-in-use gets a dedicated message, any other failure a generic one.
// (`err` is assigned on a line not visible in this excerpt.)
177 if ( err == EADDRINUSE )
178 throw std::runtime_error(
179 "Listen reported address in use" );
180 throw std::runtime_error(
"Error on listen" );
// Ask the OS which local address/port the socket ended up with.
183 detail::socklen_t len =
sizeof( sock );
184 if ( getsockname( s_, (struct ::sockaddr*)&sock, &len ) != 0 )
186 throw std::runtime_error(
187 "Unable to retrieve socket info after listen call" );
190 return sock.sin_port;
// listen( address, queue_depth ): convenience overload.
//   address     - presumably bound first via bind( address ); that line is not
//                 visible in this excerpt, TODO confirm
//   queue_depth - forwarded to listen( queue_depth ) (default 100)
// Returns the result of listen( queue_depth ) passed through ntohs, i.e. the
// listening port in host byte order.
193 listen(
const std::string& address,
int queue_depth = 100 )
196 return ntohs( listen( queue_depth ) );
// Declarations only; the definitions appear later in the file.
// accept( ): hands the descriptor from ::accept to a new FullSocket.
199 FullSocket accept( );
// connect( target ): `target` must contain a ':' separating host and port.
201 void connect(
const std::string& target );
// write_all( start, end ): sends the bytes between `start` and `end`
// (presumably the half-open range [start, end) - TODO confirm).
203 void write_all(
const char* start,
const char* end );
// read_available( start, end ): receives into the buffer beginning at `start`
// and returns a pointer just past the last byte received (start + count).
205 char* read_available(
char* start,
char* end );
// select( sockets, secs, usec ): waits on every socket in `sockets` with
// ::select and returns pointers to those flagged in the read or write set
// afterwards.
//   sockets - iterable container of FullSocket pointers (elements are used as
//             s->get( ) and pushed into a std::vector< FullSocket* >)
//   secs    - timeout seconds; a negative value passes a null timeout to
//             ::select, which then waits without a time limit
//   usec    - timeout microseconds (how `tv` is filled is not visible here)
207 template <
typename Cont >
208 static std::vector< FullSocket* >
209 select( Cont& sockets,
long secs,
long usec )
// Register each descriptor in both the read and the write set and track the
// largest descriptor value seen.
221 for (
auto& s : sockets )
223 FD_SET( s->get( ), &rd_set );
224 FD_SET( s->get( ), &wr_set );
225 nfds = ( s->get( ) > nfds ? s->get( ) : nfds );
// NOTE(review): POSIX select expects nfds to be the highest descriptor plus
// one. No increment is visible in this excerpt - confirm one happens on the
// lines not shown, otherwise the highest descriptor is never examined.
// NOTE(review): no use of `count` (for example an error check) is visible here.
228 int count = ::select( nfds,
232 ( secs >= 0 ? &tv :
nullptr ) );
// Collect every socket that select left set in either descriptor set.
233 std::vector< FullSocket* > results;
234 for (
auto& s : sockets )
236 if ( FD_ISSET( s->get( ), &rd_set ) ||
237 FD_ISSET( s->get( ), &wr_set ) )
239 results.push_back( s );
// set_option( id, val ): declaration only. The definition forwards the integer
// option to ::setsockopt and throws std::runtime_error
// ("Error setting socket options") on its error path.
248 void set_option(
int id,
int val );
// Copy constructor and copy assignment are declared, but no definition appears
// in this excerpt - presumably declared only to forbid copying, TODO confirm.
250 FullSocket(
const FullSocket& other );
251 FullSocket& operator=(
const FullSocket& other );
// Alias for the platform socket descriptor type.
257 typedef detail::socket_type socket_type;
// Default construction: holds no socket (s_ is the BAD_SOCKET sentinel).
259 SocketHandle( ) : s_( detail::BAD_SOCKET )
// Wraps descriptor `s`. Unlike FullSocket's matching constructor this one is
// not marked explicit, so a raw socket_type converts implicitly.
263 SocketHandle( socket_type s ) : s_( s )
// Copying is permitted (defaulted copy constructor), in contrast to
// FullSocket, whose copy members are only declared.
266 SocketHandle(
const SocketHandle& other ) =
default;
// reset( fd ): presumably makes this handle refer to descriptor `fd`.
// The body is not visible in this excerpt - TODO confirm.
269 reset( socket_type fd )
// Saves the current descriptor in `tmp` and sets s_ to the BAD_SOCKET sentinel.
// Presumably the body of release( ); the function header and the return
// statement are not visible in this excerpt - TODO confirm.
277 socket_type tmp = s_;
278 s_ = detail::BAD_SOCKET;
// swap( other ): presumably exchanges the stored descriptors of *this and
// `other`. Only the first step (copying other.s_ into `tmp`) is visible in this
// excerpt.
289 swap( SocketHandle& other )
291 socket_type tmp = other.s_;
// True when a descriptor is held, i.e. s_ differs from the BAD_SOCKET sentinel.
// Presumably the body of good( ); its header is not visible in this excerpt.
299 return s_ != detail::BAD_SOCKET;
// Declarations only; the definitions appear later in the file and have the same
// parameter lists as the FullSocket versions.
// write_all( start, end ): sends the bytes between `start` and `end`
// (presumably the half-open range [start, end) - TODO confirm).
302 void write_all(
const char* start,
const char* end );
// read_available( start, end ): receives into the buffer beginning at `start`
// and returns a pointer just past the last byte received (start + count).
304 char* read_available(
char* start,
char* end );
// parse_address( addr ): splits "host[:port]" text and resolves the host into
// an Address.
//   addr - text before the first ':' is the host; the text after it is fed to
//          an istringstream (presumably to read the port - the extraction line
//          is not visible here). Without a ':' the whole string is the host and
//          the port stays 0 (see the "localhost" test case).
// An empty host yields INADDR_ANY; otherwise the host is looked up with
// ::gethostbyname and the first h_addr_list entry is copied into result.node.
// Throws std::runtime_error( "Error in gethostbyname" ) - presumably when the
// lookup returns null; the condition is on a line not shown here.
// NOTE(review): gethostbyname is obsolete in POSIX and its result may live in
// static storage; FullSocket::connect in this file already uses getaddrinfo.
311 parse_address(
const std::string& addr )
// Both Address fields start out as zero.
313 Address result{ 0, 0 };
316 std::string::size_type inx = addr.find(
":" );
317 std::string host = addr;
318 if ( inx != std::string::npos )
320 host = addr.substr( 0, inx );
321 std::istringstream is( addr.substr( inx + 1 ) );
326 if ( host.size( ) == 0 )
328 result.node = INADDR_ANY;
332 struct hostent* hentp = ::gethostbyname( host.c_str( ) );
335 throw std::runtime_error(
"Error in gethostbyname" );
// Arguments of a copy call whose name is on a line not shown: destination is
// result.node, source is the first resolved address, size is sizeof( node ).
340 &result.node, *hentp->h_addr_list,
sizeof( result.node ) );
// FullSocket::set_option( id, val ): sets an integer socket option through
// ::setsockopt.
//   id  - option identifier (bind( ) passes SO_REUSEADDR)
//   val - option value; its address is passed, cast to the platform's
//         sockopt_ptr_type
// Throws std::runtime_error( "Error setting socket options" ) - presumably when
// rc signals failure; the check itself is on a line not shown here.
346 FullSocket::set_option(
int id,
int val )
348 int rc = ::setsockopt(
352 reinterpret_cast< detail::sockopt_ptr_type >( &val ),
355 throw std::runtime_error(
"Error setting socket options" );
// FullSocket::accept( ): calls ::accept on the stored socket and hands the
// resulting descriptor to a fresh FullSocket (`client`), which is presumably
// returned - the return statement is not visible in this excerpt.
359 FullSocket::accept( )
361 auto client = FullSocket( );
// Zero the peer address structure and pre-fill it before the call.
365 detail::bzero( &sock,
sizeof( sock ) );
366 sock.sin_family = AF_INET;
367 sock.sin_addr.s_addr = 0;
// NOTE(review): the address length is hard-coded to 16, whereas listen( ) uses
// sizeof( sock ) for the same purpose.
369 detail::socklen_t len = 16;
// NOTE(review): no check of the ::accept result is visible here; a failed accept
// would store the failure value in `client`.
371 client.reset(::accept( s_, (struct ::sockaddr*)&sock, &len ) );
// FullSocket::bind( address ): creates an AF_INET stream socket and binds it to
// the host/port parsed from `address`.
//   address - "host[:port]" text handed to parse_address
// Throws std::runtime_error when the precondition guard fires, when the socket
// cannot be created, or when ::bind fails (set_option may throw as well).
376 FullSocket::bind(
const std::string& address )
// Guard whose condition is on a line not shown here; the message indicates it
// rejects an object that already holds a socket.
379 throw std::runtime_error(
380 "You cannot bind to a connected socket" );
383 auto addr_info = parse_address( address );
384 sock.sin_addr.s_addr = addr_info.node;
// NOTE(review): the parse_address test expects host-order values
// (addr.port == 10000), yet the port is stored into sin_port without a visible
// htons, while listen( ) treats sin_port as network order - confirm byte order.
385 sock.sin_port = addr_info.port;
// NOTE(review): the result is compared with 0 rather than detail::BAD_SOCKET;
// confirm this check also works where socket_type is SOCKET.
387 if ( ( s_ = ::socket( AF_INET, SOCK_STREAM, 0 ) ) < 0 )
389 throw std::runtime_error(
"Unable to create socket" );
391 sock.sin_family = AF_INET;
// NOTE(review): hard-coded 16 instead of sizeof( sock ).
392 detail::socklen_t len = 16;
// Allow the address to be reused before binding.
394 set_option( SO_REUSEADDR, 1 );
396 if (::bind( s_, (struct ::sockaddr*)&sock, len ) < 0 )
397 throw std::runtime_error(
398 "Unable to bind to requested address" );
// FullSocket::connect( target ): resolves "host:port" with ::getaddrinfo and
// tries each returned address until a ::connect call succeeds.
//   target - text before the first ':' is the host name, text after it is the
//            port/service string; a target without ':' is rejected
// Throws std::runtime_error when the object is already initialized, when no
// port separator is present, when the lookup fails, or when no candidate
// address could be connected.
402 FullSocket::connect(
const std::string& target )
// Deleter type used with std::unique_ptr for the getaddrinfo result list
// (its body is not visible in this excerpt).
404 class address_cleanup
408 operator( )( struct ::addrinfo* p )
416 throw std::runtime_error(
417 "Cannot connect a socket that is already initialized" );
419 std::string::size_type sep = target.find(
':' );
420 if ( sep == std::string::npos )
422 throw std::runtime_error(
423 "Socket connect target has no port specifier" );
425 std::string hostname = target.substr( 0, sep );
426 std::string port = target.substr( sep + 1 );
428 struct ::addrinfo hints;
429 struct ::addrinfo *result =
nullptr, *rp =
nullptr;
// Ask for TCP stream sockets of any address family.
430 std::memset( &hints, 0,
sizeof( hints ) );
431 hints.ai_family = AF_UNSPEC;
432 hints.ai_socktype = SOCK_STREAM;
433 hints.ai_protocol = IPPROTO_TCP;
435 if (
int rc = ::getaddrinfo( hostname.c_str( ),
440 throw std::runtime_error(
"Unable to lookup address "
441 "information during connect "
// Hand the result list to a unique_ptr that uses address_cleanup as deleter.
444 std::unique_ptr< struct ::addrinfo, address_cleanup > result_(
// Try every candidate; each attempt uses its own temporary FullSocket. What
// happens on success is on lines not shown (presumably this object takes over
// tmp's descriptor - TODO confirm).
446 for ( rp = result; rp !=
nullptr; rp = rp->ai_next )
448 FullSocket tmp(::socket(
449 rp->ai_family, rp->ai_socktype, rp->ai_protocol ) );
454 if (::connect( tmp.get( ), rp->ai_addr, rp->ai_addrlen ) == 0 )
// NOTE(review): the message below contains the typo "connecto".
460 throw std::runtime_error(
"Unable to connecto to target address" );
// FullSocket::write_all( start, end ): sends the bytes between `start` and
// `end`, advancing a cursor `cur` (the loop and the send call are only partly
// visible in this excerpt).
// Throws std::runtime_error when the socket is not open, when sending fails,
// or when a send reports zero bytes ("Connection closed").
464 FullSocket::write_all(
const char* start,
const char* end )
467 throw std::runtime_error(
468 "Cannot send data on a socket that is not opened" );
474 static_cast< detail::send_ptr_type >( cur ),
// EINTR / EAGAIN are singled out before the generic failure - presumably to
// retry; the statement taken in that branch is on a line not shown here.
480 if ( err == EINTR || err == EAGAIN )
482 throw std::runtime_error(
"Unable to send on socket" );
484 else if ( count == 0 )
486 throw std::runtime_error(
"Connection closed" );
// FullSocket::read_available( start, end ): performs one ::recv into the buffer
// beginning at `start` and returns a pointer just past the last byte received.
//   start - beginning of the destination buffer
//   end   - presumably one past the end of the buffer (the size argument of the
//           recv call is on a line not shown here)
// Throws std::runtime_error when the socket is not open or when the read fails.
496 FullSocket::read_available(
char* start,
char* end )
499 throw std::runtime_error(
500 "Cannot read data on a socket that is not opened" );
502 auto count = ::recv( s_,
503 static_cast< detail::recv_ptr_type >( start ),
508 throw std::runtime_error(
"Unable to read data from socket" );
// A return value equal to `start` means no bytes were received.
510 return start + count;
// SocketHandle::write_all( start, end ): sends the bytes between `start` and
// `end`, advancing a cursor `cur` (the loop and the send call are only partly
// visible in this excerpt). The visible lines match FullSocket::write_all.
// Throws std::runtime_error when the socket is not open, when sending fails,
// or when a send reports zero bytes ("Connection closed").
514 SocketHandle::write_all(
const char* start,
const char* end )
517 throw std::runtime_error(
518 "Cannot send data on a socket that is not opened" );
524 static_cast< detail::send_ptr_type >( cur ),
// EINTR / EAGAIN are singled out before the generic failure - presumably to
// retry; the statement taken in that branch is on a line not shown here.
530 if ( err == EINTR || err == EAGAIN )
532 throw std::runtime_error(
"Unable to send on socket" );
534 else if ( count == 0 )
536 throw std::runtime_error(
"Connection closed" );
// SocketHandle::read_available( start, end ): performs one ::recv into the
// buffer beginning at `start` and returns a pointer just past the last byte
// received. The visible lines match FullSocket::read_available.
// Throws std::runtime_error when the socket is not open or when the read fails.
546 SocketHandle::read_available(
char* start,
char* end )
549 throw std::runtime_error(
550 "Cannot read data on a socket that is not opened" );
552 auto count = ::recv( s_,
553 static_cast< detail::recv_ptr_type >( start ),
558 throw std::runtime_error(
"Unable to read data from socket" );
// A return value equal to `start` means no bytes were received.
560 return start + count;
569 #ifdef _NDS_IMPL_ENABLE_CATCH_TESTS_
// A default-constructed FullSocket holds no descriptor, so good( ) is false.
575 TEST_CASE(
"Can create a socket object",
"[create_basic]" )
577 nds_impl::Socket::FullSocket s;
578 REQUIRE( !s.good( ) );
// Wrapping a raw descriptor number makes the object good( ) and get( ) returns
// that number. The step between the two groups of checks is on a line not shown
// in this excerpt; after it the object is empty again.
// NOTE(review): the literal -1 assumes the int-descriptor BAD_SOCKET value.
581 TEST_CASE(
"Can create a socket from a fd number",
"[create_from_fd]" )
583 nds_impl::Socket::FullSocket s{ 4 };
584 REQUIRE( s.good( ) );
585 REQUIRE( s.get( ) == 4 );
587 REQUIRE( !s.good( ) );
588 REQUIRE( s.get( ) == -1 );
// parse_address checks: host with port, host without port (port stays 0), and
// port without host (node is 0). Ports are compared as plain host-order numbers
// and the node as htonl( 0x7f000001 ), i.e. 127.0.0.1 in network order.
// Depends on "localhost" resolving to 127.0.0.1 on the test machine.
591 TEST_CASE(
"Can parse and lookup a name",
"[ext_network]" )
593 nds_impl::Socket::Address addr =
594 nds_impl::Socket::parse_address(
"localhost:10000" );
595 REQUIRE( addr.port == 10000 );
596 REQUIRE( addr.node == htonl( 0x7f000001 ) );
598 addr = nds_impl::Socket::parse_address(
"localhost" );
599 REQUIRE( addr.port == 0 );
600 REQUIRE( addr.node == htonl( 0x7f000001 ) );
602 addr = nds_impl::Socket::parse_address(
":5050" );
603 REQUIRE( addr.port == 5050 );
604 REQUIRE( addr.node == 0 );
// release( ) on an empty socket returns BAD_SOCKET and leaves it empty. Once
// the object holds descriptor 4 (set on a line not shown in this excerpt),
// release( ) returns 4 and the object is empty again.
607 TEST_CASE(
"Test release",
"[release]" )
609 nds_impl::Socket::FullSocket s;
611 REQUIRE( s.release( ) == nds_impl::Socket::detail::BAD_SOCKET );
612 REQUIRE( s.get( ) == nds_impl::Socket::detail::BAD_SOCKET );
615 REQUIRE( s.get( ) == 4 );
616 REQUIRE( s.release( ) == 4 );
617 REQUIRE( s.get( ) == nds_impl::Socket::detail::BAD_SOCKET );
// get( ) reports the stored descriptor; moving it with reset( release( ) )
// empties the source and fills the destination. The final check follows a step
// on a line not shown in this excerpt.
// NOTE(review): the literal -1 assumes the int-descriptor BAD_SOCKET value.
620 TEST_CASE(
"Test get",
"[get]" )
622 nds_impl::Socket::FullSocket s1{ 5 };
623 nds_impl::Socket::FullSocket s2;
625 REQUIRE( s1.get( ) == 5 );
626 REQUIRE( s2.get( ) == -1 );
628 s2.reset( s1.release( ) );
629 REQUIRE( s1.get( ) == -1 );
630 REQUIRE( s2.get( ) == 5 );
633 REQUIRE( s2.get( ) == -1 );
// connect( ) must throw when the object already holds a descriptor.
636 TEST_CASE(
"Test connect with an initialized socket",
"[connect]" )
638 nds_impl::Socket::FullSocket s{ 5 };
639 REQUIRE_THROWS( s.connect(
"localhost:5000" ) );
// connect( ) must throw when the target has no ':' port separator.
643 TEST_CASE(
"Test connect with out a port specified",
"[connect_no_port]" )
645 nds_impl::Socket::FullSocket s;
646 REQUIRE_THROWS( s.connect(
"localhost" ) );
// connect( ) is expected to throw when the host part before ':' is empty.
649 TEST_CASE(
"Test connect with out a host specified",
"[connect_no_host]" )
651 nds_impl::Socket::FullSocket s;
652 REQUIRE_THROWS( s.connect(
":5000" ) );
// connect( ) is expected to throw for an empty target and for a lone ":".
// Separate sockets are used so each call starts from an uninitialized object.
655 TEST_CASE(
"Test connect with no address",
"[connect_no_address]" )
657 nds_impl::Socket::FullSocket s1, s2;
658 REQUIRE_THROWS( s1.connect(
"" ) );
659 REQUIRE_THROWS( s2.connect(
":" ) );
// End-to-end loopback test: a server listens on 127.0.0.1 (no port given, so
// the port reported by listen( ) is used), a thread waits up to 5 seconds for a
// connection, accepts it and writes "hi there"; the client connects and reads
// into an 8-byte buffer, which matches the length of that message.
662 TEST_CASE(
"Test full connection" )
664 nds_impl::Socket::FullSocket server;
666 unsigned short port = server.listen(
"127.0.0.1" );
667 REQUIRE( port != 0 );
668 std::string target_address =
"127.0.0.1:";
669 target_address += std::to_string( port );
// `server` is captured by reference; the thread is joined at the end of the
// test case, before `server` goes out of scope.
671 std::thread server_thread( [&server]( ) {
672 std::vector< nds_impl::Socket::FullSocket* > sockets;
673 sockets.push_back( &server );
674 auto results = nds_impl::Socket::FullSocket::select( sockets, 5, 0 );
675 if ( results.empty( ) )
// NOTE(review): an exception escaping a std::thread function calls
// std::terminate, so unless it is caught on lines not shown here a timeout
// aborts the whole test run instead of failing this test case.
677 throw std::runtime_error(
"Client did not connect in time" );
679 auto client = server.accept( );
680 std::string data{
"hi there" };
681 client.write_all( data.data( ), data.data( ) + data.size( ) );
683 nds_impl::Socket::FullSocket client;
684 client.connect( target_address );
685 std::vector< char > buf( 8 );
686 auto end = client.read_available( buf.data( ), buf.data( ) + buf.size( ) );
// At least one byte must have been received.
687 REQUIRE( end != buf.data( ) );
693 server_thread.join( );
// Same loopback scenario as the full-connection test, but the client reads
// through a SocketHandle built from the FullSocket's descriptor
// ( client_.get( ) ), while `client_` remains the object that holds it.
696 TEST_CASE(
"Test socket wrapper" )
698 nds_impl::Socket::FullSocket server;
700 unsigned short port = server.listen(
"127.0.0.1" );
701 REQUIRE( port != 0 );
702 std::string target_address =
"127.0.0.1:";
703 target_address += std::to_string( port );
// `server` is captured by reference; the thread is joined at the end of the
// test case, before `server` goes out of scope.
705 std::thread server_thread( [&server]( ) {
706 std::vector< nds_impl::Socket::FullSocket* > sockets;
707 sockets.push_back( &server );
708 auto results = nds_impl::Socket::FullSocket::select( sockets, 5, 0 );
709 if ( results.empty( ) )
// NOTE(review): an exception escaping a std::thread function calls
// std::terminate, so unless it is caught on lines not shown here a timeout
// aborts the whole test run instead of failing this test case.
711 throw std::runtime_error(
"Client did not connect in time" );
713 auto client = server.accept( );
714 std::string data{
"hi there" };
715 client.write_all( data.data( ), data.data( ) + data.size( ) );
717 nds_impl::Socket::FullSocket client_;
718 client_.connect( target_address );
// Wrap the connected descriptor in a SocketHandle for reading.
720 nds_impl::Socket::SocketHandle client{ client_.get( ) };
721 std::vector< char > buf( 8 );
722 auto end = client.read_available( buf.data( ), buf.data( ) + buf.size( ) );
// At least one byte must have been received.
723 REQUIRE( end != buf.data( ) );
729 server_thread.join( );
732 #endif // _NDS_IMPL_ENABLE_CATCH_TESTS_
734 #endif // NDS_SOCKET_HH