Skip to content

Commit

Permalink
Add check versions of socket close
Browse files Browse the repository at this point in the history
  • Loading branch information
gordicaleksa committed Jun 25, 2024
1 parent 3a986bc commit 90aed62
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
26 changes: 26 additions & 0 deletions llmc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,32 @@ extern inline void fclose_check(FILE *fp, const char *file, int line) {

#define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__)

extern inline void sclose_check(int sockfd, const char *file, int line) {
if (close(sockfd) != 0) {
fprintf(stderr, "Error: Failed to close socket at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
exit(EXIT_FAILURE);
}
}

#define scloseCheck(sockfd) sclose_check(sockfd, __FILE__, __LINE__)

#ifdef _WIN32
extern inline void closesocket_check(int sockfd, const char *file, int line) {
if (closesocket(sockfd) != 0) {
fprintf(stderr, "Error: Failed to close socket at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
exit(EXIT_FAILURE);
}
}

#define closesocketCheck(sockfd) closesocket_check(sockfd, __FILE__, __LINE__)
#endif

extern inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {
if (fseek(fp, off, whence) != 0) {
fprintf(stderr, "Error: Failed to seek in file at %s:%d\n", file, line);
Expand Down
33 changes: 18 additions & 15 deletions llmc/zero.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Utilities for ZeRO sharding
#endif
#endif

// defines: fcloseCheck, fwriteCheck, scloseCheck, sclosesocketCheckCheck
#include "utils.h"

// ----------------------------------------------------------------------------
// Multi-GPU related
#ifdef MULTI_GPU
Expand Down Expand Up @@ -91,7 +94,7 @@ void send_nccl_id_to_clients_windows(ncclUniqueId *nccl_id, SOCKET client_socket
WSACleanup();
exit(EXIT_FAILURE);
}
closesocket(client_sockets[i]);
closesocketCheck(client_sockets[i]);
}
}
#else
Expand All @@ -101,7 +104,7 @@ void send_nccl_id_to_clients(ncclUniqueId *nccl_id, int client_sockets[], int nu
printf("Failed to send nccl_id");
exit(EXIT_FAILURE);
}
close(client_sockets[i]);
scloseCheck(client_sockets[i]);
}
}
#endif
Expand Down Expand Up @@ -143,15 +146,15 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
// Step 3) bind the socket to the address and port
if (bind(server_socket, (struct sockaddr *)&address, sizeof(address)) == SOCKET_ERROR) {
printf("Bind failed");
closesocket(server_socket);
closesocketCheck(server_socket);
WSACleanup();
exit(EXIT_FAILURE);
}

// Step 4) MAX_CLIENTS specifies the maximum number of clients that can be queued for this server
if (listen(server_socket, MAX_CLIENTS) == SOCKET_ERROR) {
printf("Listen failed");
closesocket(server_socket);
closesocketCheck(server_socket);
WSACleanup();
exit(EXIT_FAILURE);
}
Expand All @@ -161,7 +164,7 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
while (num_clients < MAX_CLIENTS) {
if ((new_socket = accept(server_socket, (struct sockaddr *)&address, &addrlen)) == INVALID_SOCKET) {
printf("Accept failed");
closesocket(server_socket);
closesocketCheck(server_socket);
WSACleanup();
exit(EXIT_FAILURE);
}
Expand All @@ -173,7 +176,7 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
send_nccl_id_to_clients_windows(&nccl_id, client_sockets, num_clients);
printf("NCCL ID sent to all clients\n");

closesocket(server_socket);
closesocketCheck(server_socket);
} else {
int num_connection_attempts = 5;
int time_to_sleep = 2;
Expand All @@ -192,7 +195,7 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
serv_addr.sin_port = htons(SERVER_PORT);
if (inet_pton(AF_INET, server_ip, &serv_addr.sin_addr) <= 0) {
printf("Invalid address or address not supported");
closesocket(client_socket);
closesocketCheck(client_socket);
WSACleanup();
exit(EXIT_FAILURE);
}
Expand All @@ -202,7 +205,7 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
printf("%d Connection failed, retrying in %d seconds\n", result->process_rank, time_to_sleep);
if (--num_connection_attempts == 0) {
printf("Failed to connect to the server\n");
closesocket(client_socket);
closesocketCheck(client_socket);
WSACleanup();
exit(EXIT_FAILURE);
}
Expand All @@ -212,13 +215,13 @@ ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* ser
// Step 4) receive the NCCL ID from the server
if (recv(client_socket, (char *)&nccl_id, sizeof(nccl_id), 0) <= 0) {
printf("Failed to receive nccl_id");
closesocket(client_socket);
closesocketCheck(client_socket);
WSACleanup();
exit(EXIT_FAILURE);
}

printf("Received NCCL ID\n");
closesocket(client_socket);
closesocketCheck(client_socket);
}

WSACleanup();
Expand Down Expand Up @@ -287,7 +290,7 @@ ncclUniqueId get_nccl_id_via_tcp(MultiGpuConfig* result, const char* server_ip)
send_nccl_id_to_clients(&nccl_id, client_sockets, num_clients);
printf("NCCL ID sent to all clients\n");

close(server_socket);
scloseCheck(server_socket);
} else {
int num_connection_attempts = 5;
int time_to_sleep = 2;
Expand Down Expand Up @@ -325,7 +328,7 @@ ncclUniqueId get_nccl_id_via_tcp(MultiGpuConfig* result, const char* server_ip)
}

printf("Received NCCL ID\n");
close(client_socket);
scloseCheck(client_socket);
}

return nccl_id;
Expand All @@ -348,8 +351,8 @@ ncclUniqueId get_nccl_id_via_fs(MultiGpuConfig* result, char* fs_path) {
ncclCheck(ncclGetUniqueId(&nccl_id));
idFile = fopen(filename, "wb");
assert(idFile != NULL);
fwrite(&nccl_id, sizeof(nccl_id), 1, idFile);
fclose(idFile);
fwriteCheck(&nccl_id, sizeof(nccl_id), 1, idFile);
fcloseCheck(idFile);
} else {
// Other ranks wait until the file is available and read the unique ID
do {
Expand All @@ -358,7 +361,7 @@ ncclUniqueId get_nccl_id_via_fs(MultiGpuConfig* result, char* fs_path) {
if (idFile != NULL) break;
} while (idFile == NULL);
freadCheck(&nccl_id, sizeof(nccl_id), 1, idFile);
fclose(idFile);
fcloseCheck(idFile);
}

return nccl_id;
Expand Down

0 comments on commit 90aed62

Please sign in to comment.