From 4eb9caaa391c956398765bd3fd5913039a5dedc6 Mon Sep 17 00:00:00 2001 From: Jonas 'Sortie' Termansen Date: Sat, 25 Feb 2017 17:00:24 +0100 Subject: [PATCH] Fix non-blocking accept4(2) and getting the Unix socket peer address. Rename the internal kernel method from accept to accept4. fixup! Fix non-blocking accept4(2) and getting the unix socket peer address. --- kernel/descriptor.cpp | 8 +- kernel/fs/user.cpp | 29 +++--- kernel/include/sortix/kernel/descriptor.h | 4 +- kernel/include/sortix/kernel/inode.h | 8 +- kernel/include/sortix/kernel/vnode.h | 2 +- kernel/inode.cpp | 4 +- kernel/io.cpp | 10 ++- kernel/net/fs.cpp | 102 ++++++++++------------ kernel/vnode.cpp | 8 +- 9 files changed, 91 insertions(+), 84 deletions(-) diff --git a/kernel/descriptor.cpp b/kernel/descriptor.cpp index ebcbc2da..d4ebd17f 100644 --- a/kernel/descriptor.cpp +++ b/kernel/descriptor.cpp @@ -850,11 +850,15 @@ int Descriptor::poll(ioctx_t* ctx, PollNode* node) return vnode->poll(ctx, node); } -Ref Descriptor::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags) +Ref Descriptor::accept4(ioctx_t* ctx, uint8_t* addr, + size_t* addrlen, int flags) { - Ref retvnode = vnode->accept(ctx, addr, addrlen, flags); + int old_ctx_dflags = ctx->dflags; + ctx->dflags = ContextFlags(old_ctx_dflags, dflags); + Ref retvnode = vnode->accept4(ctx, addr, addrlen, flags); if ( !retvnode ) return Ref(); + ctx->dflags = old_ctx_dflags; return Ref(new Descriptor(retvnode, O_READ | O_WRITE)); } diff --git a/kernel/fs/user.cpp b/kernel/fs/user.cpp index d28b9d0b..5963517f 100644 --- a/kernel/fs/user.cpp +++ b/kernel/fs/user.cpp @@ -160,7 +160,7 @@ public: void Disconnect(); void Unmount(); Channel* Connect(ioctx_t* ctx); - Channel* Accept(); + Channel* Accept(ioctx_t* ctx); Ref BootstrapNode(ino_t ino, mode_t type); Ref OpenNode(ino_t ino, mode_t type); @@ -181,8 +181,8 @@ class ServerNode : public AbstractInode public: ServerNode(Ref server); virtual ~ServerNode(); - virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags); + virtual Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags); private: Ref server; @@ -242,8 +242,8 @@ public: virtual int poll(ioctx_t* ctx, PollNode* node); virtual int rename_here(ioctx_t* ctx, Ref from, const char* oldname, const char* newname); - virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags); + virtual Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int listen(ioctx_t* ctx, int backlog); @@ -594,13 +594,17 @@ Channel* Server::Connect(ioctx_t* ctx) return channel; } -Channel* Server::Accept() +Channel* Server::Accept(ioctx_t* ctx) { ScopedLock lock(&connect_lock); listener_system_tid = CurrentThread()->system_tid; while ( !connecting && !unmounted ) + { + if ( ctx->dflags & O_NONBLOCK ) + return errno = EWOULDBLOCK, (Channel*) NULL; if ( !kthread_cond_wait_signal(&connecting_cond, &connect_lock) ) return errno = EINTR, (Channel*) NULL; + } if ( unmounted ) return errno = ECONNRESET, (Channel*) NULL; Channel* result = connecting; @@ -638,18 +642,19 @@ ServerNode::~ServerNode() server->Disconnect(); } -Ref ServerNode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags) +Ref ServerNode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags) { (void) addr; - (void) flags; + if ( flags & ~(0) ) + return errno = EINVAL, Ref(NULL); size_t out_addrlen = 0; if ( addrlen && !ctx->copy_to_dest(addrlen, &out_addrlen, sizeof(out_addrlen)) ) return Ref(NULL); Ref node(new ChannelNode); if ( !node ) return Ref(NULL); - Channel* channel = server->Accept(); + Channel* channel = server->Accept(ctx); if ( !channel ) return Ref(NULL); node->Construct(channel); @@ -1462,8 +1467,8 @@ int Unode::rename_here(ioctx_t* ctx, Ref from, const char* oldname, return ret; } -Ref Unode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/, - size_t* /*addrlen*/, int /*flags*/) +Ref Unode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/, + size_t* /*addrlen*/, int /*flags*/) { return errno = ENOTSOCK, Ref(); } diff --git a/kernel/include/sortix/kernel/descriptor.h b/kernel/include/sortix/kernel/descriptor.h index 91baaaf4..b9f2bdb3 100644 --- a/kernel/include/sortix/kernel/descriptor.h +++ b/kernel/include/sortix/kernel/descriptor.h @@ -94,8 +94,8 @@ public: int poll(ioctx_t* ctx, PollNode* node); int rename_here(ioctx_t* ctx, Ref from, const char* oldpath, const char* newpath); - Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags); + Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags); int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int listen(ioctx_t* ctx, int backlog); diff --git a/kernel/include/sortix/kernel/inode.h b/kernel/include/sortix/kernel/inode.h index 48a900f7..c0f80be0 100644 --- a/kernel/include/sortix/kernel/inode.h +++ b/kernel/include/sortix/kernel/inode.h @@ -104,8 +104,8 @@ public: virtual int poll(ioctx_t* ctx, PollNode* node) = 0; virtual int rename_here(ioctx_t* ctx, Ref from, const char* oldname, const char* newname) = 0; - virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags) = 0; + virtual Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags) = 0; virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0; virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0; virtual int listen(ioctx_t* ctx, int backlog) = 0; @@ -210,8 +210,8 @@ public: virtual int poll(ioctx_t* ctx, PollNode* node); virtual int rename_here(ioctx_t* ctx, Ref from, const char* oldname, const char* newname); - virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, - int flags); + virtual Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int listen(ioctx_t* ctx, int backlog); diff --git a/kernel/include/sortix/kernel/vnode.h b/kernel/include/sortix/kernel/vnode.h index 996f0305..da768798 100644 --- a/kernel/include/sortix/kernel/vnode.h +++ b/kernel/include/sortix/kernel/vnode.h @@ -93,7 +93,7 @@ public: int poll(ioctx_t* ctx, PollNode* node); int rename_here(ioctx_t* ctx, Ref from, const char* oldname, const char* newname); - Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags); + Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags); int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int listen(ioctx_t* ctx, int backlog); diff --git a/kernel/inode.cpp b/kernel/inode.cpp index b078a773..93fde2c7 100644 --- a/kernel/inode.cpp +++ b/kernel/inode.cpp @@ -512,8 +512,8 @@ int AbstractInode::rename_here(ioctx_t* /*ctx*/, Ref /*from*/, return errno = ENOTDIR, -1; } -Ref AbstractInode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/, - size_t* /*addrlen*/, int /*flags*/) +Ref AbstractInode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/, + size_t* /*addrlen*/, int /*flags*/) { return errno = ENOTSOCK, Ref(); } diff --git a/kernel/io.cpp b/kernel/io.cpp index bef94f11..9ddc4006 100644 --- a/kernel/io.cpp +++ b/kernel/io.cpp @@ -731,13 +731,15 @@ int sys_accept4(int fd, void* addr, size_t* addrlen, int flags) int fdflags = 0; if ( flags & SOCK_CLOEXEC ) fdflags |= FD_CLOEXEC; if ( flags & SOCK_CLOFORK ) fdflags |= FD_CLOFORK; - flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK); + int descflags = 0; + if ( flags & SOCK_NONBLOCK ) descflags |= O_NONBLOCK; + flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK | SOCK_NONBLOCK); ioctx_t ctx; SetupUserIOCtx(&ctx); - Ref conn = desc->accept(&ctx, (uint8_t*) addr, addrlen, flags); + Ref conn = desc->accept4(&ctx, (uint8_t*) addr, addrlen, flags); if ( !conn ) return -1; - if ( flags & SOCK_NONBLOCK ) - conn->SetFlags(conn->GetFlags() | O_NONBLOCK); + if ( descflags ) + conn->SetFlags(conn->GetFlags() | descflags); return CurrentProcess()->GetDTable()->Allocate(conn, fdflags); } diff --git a/kernel/net/fs.cpp b/kernel/net/fs.cpp index bb7db55e..6d0c29f3 100644 --- a/kernel/net/fs.cpp +++ b/kernel/net/fs.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -82,8 +83,8 @@ class StreamSocket : public AbstractInode public: StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref manager); virtual ~StreamSocket(); - virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, - int flags); + virtual Ref accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, + int flags); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int listen(ioctx_t* ctx, int backlog); @@ -116,6 +117,7 @@ public: /* For use by Manager. */ StreamSocket* first_pending; StreamSocket* last_pending; struct sockaddr_un* bound_address; + size_t bound_address_size; bool is_listening; bool is_connected; bool is_refused; @@ -167,6 +169,7 @@ StreamSocket::StreamSocket(uid_t owner, gid_t group, mode_t mode, this->first_pending = NULL; this->last_pending = NULL; this->bound_address = NULL; + this->bound_address_size = 0; this->is_listening = false; this->is_connected = false; this->is_refused = false; @@ -181,11 +184,11 @@ StreamSocket::~StreamSocket() { if ( is_listening ) manager->Unlisten(this); - delete[] bound_address; + free(bound_address); } -Ref StreamSocket::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, - int flags) +Ref StreamSocket::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, + int flags) { ScopedLock lock(&socket_lock); if ( !is_listening ) @@ -198,33 +201,25 @@ int StreamSocket::do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) if ( is_connected || is_listening || bound_address ) return errno = EINVAL, -1; size_t path_offset = offsetof(struct sockaddr_un, sun_path); - size_t path_len = (path_offset - addrsize) / sizeof(char); if ( addrsize < path_offset ) return errno = EINVAL, -1; - uint8_t* buffer = new uint8_t[addrsize]; - if ( !buffer ) + size_t path_len = path_offset - addrsize; + struct sockaddr_un* address = (struct sockaddr_un*) malloc(addrsize); + if ( !address ) return -1; - if ( ctx->copy_from_src(buffer, addr, addrsize) ) - { - struct sockaddr_un* address = (struct sockaddr_un*) buffer; - if ( address->sun_family == AF_UNIX ) - { - bool found_nul = false; - for ( size_t i = 0; !found_nul && i < path_len; i++ ) - if ( address->sun_path[i] == '\0' ) - found_nul = true; - if ( found_nul ) - { - bound_address = address; - return 0; - } - errno = EINVAL; - } - else - errno = EAFNOSUPPORT; - } - delete[] buffer; - return -1; + if ( !ctx->copy_from_src(address, addr, addrsize) ) + return free(address), -1; + if ( address->sun_family != AF_UNIX ) + return free(address), errno = EAFNOSUPPORT, -1; + bool found_nul = false; + for ( size_t i = 0; !found_nul && i < path_len; i++ ) + if ( address->sun_path[i] == '\0' ) + found_nul = true; + if ( !found_nul ) + return free(address), errno = EINVAL, -1; + bound_address = address; + bound_address_size = addrsize; + return 0; } int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) @@ -465,40 +460,43 @@ int Manager::AcceptPoll(StreamSocket* socket, ioctx_t* /*ctx*/, PollNode* node) } Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, - uint8_t* addr, size_t* addrsize, int /*flags*/) + uint8_t* addr, size_t* addrsize, int flags) { + if ( flags & ~(0) ) + return errno = EINVAL, Ref(NULL); + ScopedLock lock(&manager_lock); - // TODO: Support non-blocking accept! while ( !socket->first_pending ) + { + if ( (ctx->dflags & O_NONBLOCK) || (flags & SOCK_NONBLOCK) ) + return errno = EWOULDBLOCK, Ref(NULL); if ( !kthread_cond_wait_signal(&socket->pending_cond, &manager_lock) ) return errno = EINTR, Ref(NULL); - - StreamSocket* client = socket->first_pending; - - struct sockaddr_un* client_addr = client->bound_address; - size_t client_addr_size = offsetof(struct sockaddr_un, sun_path) + - (strlen(client_addr->sun_path)+1) * sizeof(char); - - if ( addr ) - { - size_t caller_addrsize; - if ( !ctx->copy_from_src(&caller_addrsize, addrsize, sizeof(caller_addrsize)) ) - return Ref(NULL); - if ( caller_addrsize < client_addr_size ) - return errno = ERANGE, Ref(NULL); - if ( !ctx->copy_from_src(addrsize, &client_addr_size, sizeof(client_addr_size)) ) - return Ref(NULL); - if ( !ctx->copy_to_dest(addr, client_addr, client_addr_size) ) - return Ref(NULL); } - // TODO: Give the caller the address of the remote! + struct sockaddr_un* bound_address = socket->bound_address; + size_t bound_address_size = socket->bound_address_size; + if ( addr ) + { + size_t used_addrsize; + if ( !ctx->copy_from_src(&used_addrsize, addrsize, + sizeof(used_addrsize)) ) + return Ref(NULL); + if ( bound_address_size < used_addrsize ) + used_addrsize = bound_address_size; + if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) ) + return Ref(NULL); + if ( !ctx->copy_to_dest(addrsize, &used_addrsize, + sizeof(used_addrsize)) ) + return Ref(NULL); + } Ref server(new StreamSocket(0, 0, 0666, Ref(this))); if ( !server ) return Ref(NULL); + StreamSocket* client = socket->first_pending; QueuePop(&socket->first_pending, &socket->last_pending); if ( !client->outgoing.Connect(&server->incoming) ) @@ -513,10 +511,6 @@ Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, client->is_connected = true; server->is_connected = true; - // TODO: Should the server socket inherit the address of the listening - // socket or perhaps the one of the client's source/destination, or - // nothing at all? - kthread_cond_signal(&client->accepted_cond); return server; diff --git a/kernel/vnode.cpp b/kernel/vnode.cpp index f1f45ff3..2a59fcc1 100644 --- a/kernel/vnode.cpp +++ b/kernel/vnode.cpp @@ -391,12 +391,14 @@ int Vnode::poll(ioctx_t* ctx, PollNode* node) return inode->poll(ctx, node); } -Ref Vnode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags) +Ref Vnode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, + int flags) { - Ref retinode = inode->accept(ctx, addr, addrlen, flags); + Ref retinode = inode->accept4(ctx, addr, addrlen, flags); if ( !retinode ) return Ref(); - return Ref(new Vnode(retinode, Ref(), retinode->ino, retinode->dev)); + return Ref(new Vnode(retinode, Ref(), retinode->ino, + retinode->dev)); } int Vnode::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen)