Skip to content

Commit 17fced9

Browse files
committed
Add count operator
Signed-off-by: Paolo Di Tommaso <[email protected]>
1 parent 2b4c5b3 commit 17fced9

File tree

4 files changed

+109
-47
lines changed

4 files changed

+109
-47
lines changed

modules/nextflow/src/main/groovy/nextflow/extension/CollectOp.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class CollectOp {
6262
Op.bind(processor, target, msg)
6363
}
6464

65-
DataflowHelper.subscribeImpl(source, events)
65+
DataflowHelper.subscribeImpl(source, true, events)
6666
return target
6767
}
6868

modules/nextflow/src/main/groovy/nextflow/extension/OperatorImpl.groovy

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class OperatorImpl {
9191
* @param closure
9292
* @return
9393
*/
94+
@Deprecated
9495
DataflowWriteChannel chain(final DataflowReadChannel<?> source, final Closure closure) {
9596
final target = CH.createBy(source)
9697
newOperator(source, target, stopErrorListener(source,target), new ChainWithClosure(closure))
@@ -104,6 +105,7 @@ class OperatorImpl {
104105
* @param closure
105106
* @return
106107
*/
108+
@Deprecated
107109
DataflowWriteChannel chain(final DataflowReadChannel<?> source, final Map<String, Object> params, final Closure closure) {
108110
return ChainOp.create()
109111
.withSource(source)
@@ -423,8 +425,7 @@ class OperatorImpl {
423425
* @return
424426
*/
425427
DataflowWriteChannel count(final DataflowReadChannel source ) {
426-
final target = count0(source, null)
427-
return target
428+
return count0(source, null)
428429
}
429430

430431
/**
@@ -435,33 +436,25 @@ class OperatorImpl {
435436
* @return
436437
*/
437438
DataflowWriteChannel count(final DataflowReadChannel source, final Object criteria ) {
438-
final target = count0(source, criteria)
439-
return target
439+
return count0(source, criteria)
440440
}
441441

442442
private static DataflowVariable count0(DataflowReadChannel<?> source, Object criteria) {
443443

444444
final target = new DataflowVariable()
445445
final discriminator = criteria != null ? new BooleanReturningMethodInvoker("isCase") : null
446446

447-
if( source instanceof DataflowExpression) {
448-
source.whenBound { item ->
449-
discriminator == null || discriminator.invoke(criteria, item) ? target.bind(1) : target.bind(0)
450-
}
451-
}
452-
else {
453-
final action = { current, item ->
454-
discriminator == null || discriminator.invoke(criteria, item) ? current+1 : current
455-
}
456-
457-
ReduceOp .create()
458-
.withSource(source)
459-
.withTarget(target)
460-
.withSeed(0)
461-
.withAction(action)
462-
.apply()
447+
final action = { current, item ->
448+
discriminator == null || discriminator.invoke(criteria, item) ? current+1 : current
463449
}
464450

451+
ReduceOp .create()
452+
.withSource(source)
453+
.withTarget(target)
454+
.withSeed(0)
455+
.withAction(action)
456+
.apply()
457+
465458
return target
466459
}
467460

modules/nextflow/src/main/groovy/nextflow/extension/ReduceOp.groovy

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@
1717

1818
package nextflow.extension
1919

20+
import static nextflow.extension.DataflowHelper.*
21+
2022
import groovy.transform.CompileStatic
2123
import groovy.util.logging.Slf4j
2224
import groovyx.gpars.dataflow.DataflowReadChannel
2325
import groovyx.gpars.dataflow.DataflowVariable
26+
import groovyx.gpars.dataflow.expression.DataflowExpression
2427
import groovyx.gpars.dataflow.operator.DataflowEventAdapter
2528
import groovyx.gpars.dataflow.operator.DataflowProcessor
2629
import nextflow.Channel
2730
import nextflow.Global
2831
import nextflow.Session
2932
import nextflow.extension.op.Op
30-
3133
/**
3234
* Implements reduce operator logic
3335
*
@@ -88,30 +90,12 @@ class ReduceOp {
8890
throw new IllegalArgumentException("Missing reduce operator source channel")
8991
if( target==null )
9092
target = new DataflowVariable()
91-
93+
final stopOnFirst = source instanceof DataflowExpression
9294
// the *accumulator* value
9395
def accum = this.seed
9496

9597
// intercepts operator events
9698
final listener = new DataflowEventAdapter() {
97-
/*
98-
* call the passed closure each time
99-
*/
100-
void afterRun(final DataflowProcessor processor, final List<Object> messages) {
101-
final item = Op.unwrap(messages).get(0)
102-
final value = accum == null ? item : action.call(accum, item)
103-
104-
if( value == Channel.VOID ) {
105-
// do nothing
106-
}
107-
else if( value == Channel.STOP ) {
108-
processor.terminate()
109-
}
110-
else {
111-
accum = value
112-
}
113-
}
114-
11599
/*
116100
* when terminates bind the result value
117101
*/
@@ -129,13 +113,20 @@ class ReduceOp {
129113
}
130114
}
131115

132-
ChainOp.create()
133-
.withSource(source)
134-
.withTarget(CH.create())
135-
.withListener(listener)
116+
final parameters = new OpParams()
117+
.withInput(source)
136118
.withAccumulator(true)
137-
.withAction({true})
138-
.apply()
119+
.withListener(listener)
120+
121+
newOperator(parameters) {
122+
final value = accum == null ? it : action.call(accum, it)
123+
final proc = getDelegate() as DataflowProcessor
124+
if( value!=Channel.VOID && value!=Channel.STOP ) {
125+
accum = value
126+
}
127+
if( stopOnFirst || value==Channel.STOP )
128+
proc.terminate()
129+
}
139130

140131
return target
141132
}

modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,84 @@ class ProvTest extends Dsl2Spec {
451451
def upstream1 = upstreamTasksOf('p2')
452452
upstream1.size() == 1
453453
upstream1.first.name == 'p1 (5)'
454+
}
455+
456+
def 'should track provenance with collect operator'() {
457+
when:
458+
dsl_eval(globalConfig(), '''
459+
workflow {
460+
channel.of(1,2,3) | p1 | collect | p2
461+
}
462+
463+
process p1 {
464+
input: val(x)
465+
output: val(y)
466+
exec:
467+
y = x
468+
}
469+
470+
process p2 {
471+
input: val(x)
472+
exec:
473+
println x
474+
}
475+
''')
454476

477+
then:
478+
def t1 = upstreamTasksOf('p2')
479+
t1.name == ['p1 (1)', 'p1 (2)', 'p1 (3)']
455480
}
481+
482+
def 'should track provenance with count value operator'() {
483+
when:
484+
dsl_eval(globalConfig(), '''
485+
workflow {
486+
channel.value(1) | p1 | count | p2
487+
}
488+
489+
process p1 {
490+
input: val(x)
491+
output: val(y)
492+
exec:
493+
y = x
494+
}
495+
496+
process p2 {
497+
input: val(x)
498+
exec:
499+
println x
500+
}
501+
''')
502+
503+
then:
504+
def t1 = upstreamTasksOf('p2')
505+
t1.name == ['p1']
506+
}
507+
508+
def 'should track provenance with count many operator'() {
509+
when:
510+
dsl_eval(globalConfig(), '''
511+
workflow {
512+
channel.of(1,2,3) | p1 | count | p2
513+
}
514+
515+
process p1 {
516+
input: val(x)
517+
output: val(y)
518+
exec:
519+
y = x
520+
}
521+
522+
process p2 {
523+
input: val(x)
524+
exec:
525+
println x
526+
}
527+
''')
528+
529+
then:
530+
def t1 = upstreamTasksOf('p2')
531+
t1.name == ['p1 (1)', 'p1 (2)', 'p1 (3)']
532+
}
533+
456534
}

0 commit comments

Comments
 (0)