Skip to content

Commit 6fbeaed

Browse files
committed
feat(Resources): introduce fabric in SSHCE
1 parent 5f268bd commit 6fbeaed

File tree

3 files changed

+320
-470
lines changed

3 files changed

+320
-470
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies:
2020
- elasticsearch <7.14
2121
- elasticsearch-dsl
2222
- opensearch-py
23+
- fabric
2324
- fts3
2425
- gitpython >=2.1.0
2526
- m2crypto >=0.38.0

src/DIRAC/Resources/Computing/SSHBatchComputingElement.py

Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs
1+
""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs
22
directly through ssh
33
"""
44

@@ -13,64 +13,78 @@
1313

1414

1515
class SSHBatchComputingElement(SSHComputingElement):
16-
#############################################################################
1716
def __init__(self, ceUniqueID):
1817
"""Standard constructor."""
1918
super().__init__(ceUniqueID)
2019

21-
self.ceType = "SSHBatch"
22-
self.sshHost = []
20+
self.connections = {}
2321
self.execution = "SSHBATCH"
2422

2523
def _reset(self):
2624
"""Process CE parameters and make necessary adjustments"""
25+
# Get the Batch System instance
2726
result = self._getBatchSystem()
2827
if not result["OK"]:
2928
return result
29+
30+
# Get the location of the remote directories
3031
self._getBatchSystemDirectoryLocations()
3132

32-
self.user = self.ceParameters["SSHUser"]
33+
# Get the SSH parameters
34+
self.timeout = self.ceParameters.get("Timeout", self.timeout)
35+
self.user = self.ceParameters.get("SSHUser", self.user)
36+
port = self.ceParameters.get("SSHPort", None)
37+
password = self.ceParameters.get("SSHPassword", None)
38+
key = self.ceParameters.get("SSHKey", None)
39+
tunnel = self.ceParameters.get("SSHTunnel", None)
40+
41+
# Get submission parameters
42+
self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions)
43+
self.preamble = self.ceParameters.get("Preamble", self.preamble)
44+
self.account = self.ceParameters.get("Account", self.account)
3345
self.queue = self.ceParameters["Queue"]
3446
self.log.info("Using queue: ", self.queue)
3547

36-
self.submitOptions = self.ceParameters.get("SubmitOptions", "")
37-
self.preamble = self.ceParameters.get("Preamble", "")
38-
self.account = self.ceParameters.get("Account", "")
39-
40-
# Prepare all the hosts
41-
for hPar in self.ceParameters["SSHHost"].strip().split(","):
42-
host = hPar.strip().split("/")[0]
43-
result = self._prepareRemoteHost(host=host)
44-
if result["OK"]:
45-
self.log.info(f"Host {host} registered for usage")
46-
self.sshHost.append(hPar.strip())
48+
# Get output and error templates
49+
self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate)
50+
self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate)
51+
52+
# Prepare the remote hosts
53+
for host in self.ceParameters.get("SSHHost", "").strip().split(","):
54+
hostDetails = host.strip().split("/")
55+
if len(hostDetails) > 1:
56+
hostname = hostDetails[0]
57+
maxJobs = int(hostDetails[1])
4758
else:
48-
self.log.error("Failed to initialize host", host)
59+
hostname = hostDetails[0]
60+
maxJobs = self.ceParameters.get("MaxTotalJobs", 0)
61+
62+
connection = self._getConnection(hostname, self.user, port, password, key, tunnel)
63+
64+
result = self._prepareRemoteHost(connection)
65+
if not result["OK"]:
4966
return result
5067

68+
self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs}
69+
self.log.info(f"Host {hostname} registered for usage")
70+
5171
return S_OK()
5272

5373
#############################################################################
74+
5475
def submitJob(self, executableFile, proxy, numberOfJobs=1):
5576
"""Method to submit job"""
56-
5777
# Choose eligible hosts, rank them by the number of available slots
5878
rankHosts = {}
5979
maxSlots = 0
60-
for host in self.sshHost:
61-
thost = host.split("/")
62-
hostName = thost[0]
63-
maxHostJobs = 1
64-
if len(thost) > 1:
65-
maxHostJobs = int(thost[1])
66-
67-
result = self._getHostStatus(hostName)
80+
for _, details in self.connections.items():
81+
result = self._getHostStatus(details["connection"])
6882
if not result["OK"]:
6983
continue
70-
slots = maxHostJobs - result["Value"]["Running"]
84+
slots = details["maxJobs"] - result["Value"]["Running"]
7185
if slots > 0:
7286
rankHosts.setdefault(slots, [])
73-
rankHosts[slots].append(hostName)
87+
rankHosts[slots].append(details["connection"])
7488
if slots > maxSlots:
7589
maxSlots = slots
7690

@@ -96,18 +110,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
96110
restJobs = numberOfJobs
97111
submittedJobs = []
98112
stampDict = {}
113+
batchSystemName = self.batchSystem.__class__.__name__.lower()
114+
99115
for slots in range(maxSlots, 0, -1):
100116
if slots not in rankHosts:
101117
continue
102-
for host in rankHosts[slots]:
103-
result = self._submitJobToHost(submitFile, min(slots, restJobs), host)
118+
for connection in rankHosts[slots]:
119+
result = self._submitJobToHost(connection, submitFile, min(slots, restJobs))
104120
if not result["OK"]:
105121
continue
106122

107-
nJobs = len(result["Value"])
123+
batchIDs, jobStamps = result["Value"]
124+
125+
nJobs = len(batchIDs)
108126
if nJobs > 0:
109-
submittedJobs.extend(result["Value"])
110-
stampDict.update(result.get("PilotStampDict", {}))
127+
jobIDs = [
128+
f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}"
129+
for _id in batchIDs
130+
]
131+
submittedJobs.extend(jobIDs)
132+
for iJob, jobID in enumerate(jobIDs):
133+
stampDict[jobID] = jobStamps[iJob]
134+
111135
restJobs = restJobs - nJobs
112136
if restJobs <= 0:
113137
break
@@ -121,6 +145,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
121145
result["PilotStampDict"] = stampDict
122146
return result
123147

148+
#############################################################################
149+
124150
def killJob(self, jobIDs):
125151
"""Kill specified jobs"""
126152
jobIDList = list(jobIDs)
@@ -136,7 +162,7 @@ def killJob(self, jobIDs):
136162

137163
failed = []
138164
for host, jobIDList in hostDict.items():
139-
result = self._killJobOnHost(jobIDList, host)
165+
result = self._killJobOnHost(self.connections[host]["connection"], jobIDList)
140166
if not result["OK"]:
141167
failed.extend(jobIDList)
142168
message = result["Message"]
@@ -149,16 +175,17 @@ def killJob(self, jobIDs):
149175

150176
return result
151177

178+
#############################################################################
179+
152180
def getCEStatus(self):
153181
"""Method to return information on running and pending jobs."""
154182
result = S_OK()
155183
result["SubmittedJobs"] = self.submittedJobs
156184
result["RunningJobs"] = 0
157185
result["WaitingJobs"] = 0
158186

159-
for host in self.sshHost:
160-
thost = host.split("/")
161-
resultHost = self._getHostStatus(thost[0])
187+
for _, details in self.connections:
188+
resultHost = self._getHostStatus(details["connection"])
162189
if resultHost["OK"]:
163190
result["RunningJobs"] += resultHost["Value"]["Running"]
164191

@@ -167,6 +194,8 @@ def getCEStatus(self):
167194

168195
return result
169196

197+
#############################################################################
198+
170199
def getJobStatus(self, jobIDList):
171200
"""Get status of the jobs in the given list"""
172201
hostDict = {}
@@ -178,7 +207,7 @@ def getJobStatus(self, jobIDList):
178207
resultDict = {}
179208
failed = []
180209
for host, jobIDList in hostDict.items():
181-
result = self._getJobStatusOnHost(jobIDList, host)
210+
result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList)
182211
if not result["OK"]:
183212
failed.extend(jobIDList)
184213
continue
@@ -189,3 +218,16 @@ def getJobStatus(self, jobIDList):
189218
resultDict[job] = PilotStatus.UNKNOWN
190219

191220
return S_OK(resultDict)
221+
222+
#############################################################################
223+
224+
def getJobOutput(self, jobID, localDir=None):
225+
"""Get the specified job standard output and error files. If the localDir is provided,
226+
the output is returned as file in this directory. Otherwise, the output is returned
227+
as strings.
228+
"""
229+
self.log.verbose("Getting output for jobID", jobID)
230+
231+
# host can be retrieved from the path of the jobID
232+
host = os.path.dirname(urlparse(jobID).path).lstrip("/")
233+
return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir)

0 commit comments

Comments
 (0)