Skip to content

Commit 136ff55

Browse files
yihuaxuzzhang37
authored andcommitted
UCT/GAUDI: Free the fds when called md_close (openucx#11017)
*UCT/GAUDI: Free the fds when called md_close
1 parent 94a58ed commit 136ff55

File tree

4 files changed

+35
-11
lines changed

4 files changed

+35
-11
lines changed

src/uct/gaudi/base/gaudi_base.c

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,37 @@
2020
#include <hlthunk.h>
2121
#include <synapse_api.h>
2222

23-
int uct_gaudi_base_get_fd(int device_id) {
23+
int uct_gaudi_base_get_fd(int device_id, bool *fd_created) {
2424
synDeviceInfo deviceInfo;
2525

2626
if (synDeviceGetInfo(-1, &deviceInfo) != synSuccess) {
27-
return hlthunk_open_by_module_id(device_id);
27+
int fd = hlthunk_open_by_module_id(device_id);
28+
if (fd <0) {
29+
ucs_info("Failed to get device fd via hlthunk_open_by_module_id, id %d", device_id);
30+
fd = hlthunk_open(HLTHUNK_DEVICE_DONT_CARE, NULL);
31+
}
32+
33+
if (fd >=0 && fd_created != NULL) {
34+
*fd_created = true;
35+
}
36+
return fd;
2837
}
2938

3039
return deviceInfo.fd;
3140
}
3241

42+
void uct_gaudi_base_close_fd(int fd, bool fd_created) {
43+
if (fd_created && fd >= 0) {
44+
hlthunk_close(fd);
45+
}
46+
}
47+
48+
void uct_gaudi_base_close_dmabuf_fd(int fd) {
49+
if (fd >= 0) {
50+
close(fd);
51+
}
52+
}
53+
3354
ucs_status_t uct_gaudi_base_get_sysdev(int fd, ucs_sys_device_t* sys_dev) {
3455
ucs_status_t status;
3556
char pci_bus_id[13];

src/uct/gaudi/base/gaudi_base.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
#ifndef GAUDI_BASE_H_
77
#define GAUDI_BASE_H_
88

9+
#include <stdbool.h>
910
#include <uct/base/uct_iface.h>
1011
#include <uct/base/uct_md.h>
1112
#include "scal.h"
1213

13-
int uct_gaudi_base_get_fd(int device_id);
14+
int uct_gaudi_base_get_fd(int device_id, bool *fd_created);
15+
void uct_gaudi_base_close_fd(int fd, bool fd_created);
16+
void uct_gaudi_base_close_dmabuf_fd(int fd);
1417
ucs_status_t uct_gaudi_base_get_sysdev(int fd, ucs_sys_device_t* sys_dev);
1518
ucs_status_t uct_gaudi_base_get_info(int fd, uint64_t *device_base_allocated_address, uint64_t *device_base_address,
1619
uint64_t *totalSize, int *dmabuf_fd);

src/uct/gaudi/gaudi_gdr/gaudi_gdr_md.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,8 @@ static ucs_status_t uct_gaudi_md_query(uct_md_h md, uct_md_attr_v2_t *attr)
4242
static void uct_gaudi_md_close(uct_md_h uct_md)
4343
{
4444
uct_gaudi_md_t *md = ucs_derived_of(uct_md, uct_gaudi_md_t);
45-
if (md->dmabuf_fd >= 0) {
46-
close(md->dmabuf_fd);
47-
}
48-
if (md->fd >= 0) {
49-
close(md->fd);
50-
}
45+
uct_gaudi_base_close_dmabuf_fd(md->dmabuf_fd);
46+
uct_gaudi_base_close_fd(md->fd, md->fd_created);
5147
ucs_free(md);
5248
}
5349

@@ -156,6 +152,7 @@ uct_gaudi_md_open(uct_component_h component, const char *md_name,
156152
uct_gaudi_md_config_t);
157153
uct_gaudi_md_t *md;
158154
ucs_status_t status;
155+
bool fd_created = false;
159156
int fd;
160157

161158
md = ucs_malloc(sizeof(uct_gaudi_md_t), "uct_gaudi_md_t");
@@ -164,7 +161,7 @@ uct_gaudi_md_open(uct_component_h component, const char *md_name,
164161
return UCS_ERR_NO_MEMORY;
165162
}
166163

167-
fd = uct_gaudi_base_get_fd(config->device_id);
164+
fd = uct_gaudi_base_get_fd(config->device_id, &fd_created);
168165
if (fd <0) {
169166
ucs_error("Failed to get device fd");
170167
status = UCS_ERR_NO_DEVICE;
@@ -187,7 +184,8 @@ uct_gaudi_md_open(uct_component_h component, const char *md_name,
187184
goto err_close_dmabuf;
188185
}
189186

190-
md->fd = dup(fd);
187+
md->fd = fd;
188+
md->fd_created = fd_created;
191189
md->super.ops = &md_ops;
192190
md->super.component = &uct_gaudi_gdr_component;
193191

src/uct/gaudi/gaudi_gdr/gaudi_gdr_md.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifndef GAUDI_MD_H
77
#define GAUDI_MD_H
88

9+
#include <stdbool.h>
910
#include <uct/base/uct_md.h>
1011
#include <ucs/config/types.h>
1112

@@ -14,6 +15,7 @@ extern uct_component_t uct_gaudi_gdr_component;
1415
typedef struct uct_gaudi_md {
1516
uct_md_t super;
1617
int fd;
18+
bool fd_created;
1719
uint64_t device_base_allocated_address;
1820
uint64_t device_base_address;
1921
uint64_t totalSize;

0 commit comments

Comments
 (0)